Skip to content

Commit

Permalink
[GLUTEN-1632][CH]Daily Update Clickhouse Version (20240815) (#6848)
Browse files Browse the repository at this point in the history
* [GLUTEN-1632][CH]Daily Update Clickhouse Version (20240815)

* Fix Build due to ClickHouse/ClickHouse#68107

* Fix Build due to ClickHouse/ClickHouse#68135

* Fix UT due to ClickHouse/ClickHouse#68135

* Add ut for #2584

- Rebase failed with ClickHouse/ClickHouse#67879, and hence we can remove Kyligence/ClickHouse#454

(cherry picked from commit 583aa8d6566a9e1c0924c1a3ab1d315fcc229fa6)

* Fix CH BUG due to ClickHouse/ClickHouse#68135

see Kyligence/ClickHouse@d87dbba

* Resolve conflict

---------

Co-authored-by: kyligence-git <[email protected]>
Co-authored-by: Chang Chen <[email protected]>
  • Loading branch information
3 people authored Aug 15, 2024
1 parent 4ae3cb1 commit 2ad689b
Show file tree
Hide file tree
Showing 38 changed files with 346 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -937,4 +937,37 @@ class GlutenClickHouseNativeWriteTableSuite
_ => {})
)
}

test("GLUTEN-2584: fix native write and read mismatch about complex types") {
def table(format: String): String = s"t_$format"
def create(format: String, table_name: Option[String] = None): String =
s"""CREATE TABLE ${table_name.getOrElse(table(format))}(
| id INT,
| info STRUCT<name:STRING, age:INT>,
| data MAP<STRING, INT>,
| values ARRAY<INT>
|) stored as $format""".stripMargin
def insert(format: String, table_name: Option[String] = None): String =
s"""INSERT overwrite ${table_name.getOrElse(table(format))} VALUES
| (6, null, null, null);
""".stripMargin

nativeWrite2(
format => (table(format), create(format), insert(format)),
(table_name, format) => {
val vanilla_table = s"${table_name}_v"
val vanilla_create = create(format, Some(vanilla_table))
vanillaWrite {
withDestinationTable(vanilla_table, Option(vanilla_create)) {
checkInsertQuery(insert(format, Some(vanilla_table)), checkNative = false)
}
}
val rowsFromOriginTable =
spark.sql(s"select * from $vanilla_table").collect()
val dfFromWriteTable =
spark.sql(s"select * from $table_name")
checkAnswer(dfFromWriteTable, rowsFromOriginTable)
}
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ trait NativeWriteChecker
}
}

def vanillaWrite(block: => Unit): Unit = {
withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) {
block
}
}

def withSource(df: Dataset[Row], viewName: String, pairs: (String, String)*)(
block: => Unit): Unit = {
withSQLConf(pairs: _*) {
Expand Down
5 changes: 2 additions & 3 deletions cpp-ch/clickhouse.version
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
CH_ORG=Kyligence
CH_BRANCH=rebase_ch/20240809
CH_COMMIT=01e780d46d9

CH_BRANCH=rebase_ch/20240815
CH_COMMIT=d87dbba64fc
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ createAggregateFunctionBloomFilter(const std::string & name, const DataTypes & a
if (type != Field::Types::Int64 && type != Field::Types::UInt64)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be Int64 or UInt64", name);

if ((type == Field::Types::Int64 && parameters[i].get<Int64>() < 0))
if ((type == Field::Types::Int64 && parameters[i].safeGet<Int64>() < 0))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Parameter for aggregate function {} should be non-negative number", name);

return parameters[i].get<UInt64>();
return parameters[i].safeGet<UInt64>();
};

filter_size = get_parameter(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argu
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name);

bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>();
bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).safeGet<bool>();
const UInt32 p1 = DB::getDecimalPrecision(*data_type);
const UInt32 s1 = DB::getDecimalScale(*data_type);
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
Expand Down
15 changes: 9 additions & 6 deletions cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@
#include <Interpreters/JIT/CompiledExpressionCache.h>
#include <Parser/RelParser.h>
#include <Parser/SerializedPlanParser.h>
#include <Planner/PlannerActionsVisitor.h>
#include <Processors/Chunk.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <QueryPipeline/printPipeline.h>
#include <Storages/Cache/CacheManager.h>
#include <Storages/Output/WriteBufferBuilder.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Storages/SubstraitSource/ReadBufferBuilder.h>
Expand All @@ -72,7 +74,6 @@
#include <Common/LoggerExtend.h>
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>
#include <Storages/Cache/CacheManager.h>

namespace DB
{
Expand Down Expand Up @@ -463,20 +464,22 @@ const DB::ColumnWithTypeAndName * NestedColumnExtractHelper::findColumn(const DB
const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeType(
DB::ActionsDAG & actions_dag,
const DB::ActionsDAG::Node * node,
const std::string & type_name,
const DataTypePtr & cast_to_type,
const std::string & result_name,
CastType cast_type)
{
DB::ColumnWithTypeAndName type_name_col;
type_name_col.name = type_name;
type_name_col.name = cast_to_type->getName();
type_name_col.column = DB::DataTypeString().createColumnConst(0, type_name_col.name);
type_name_col.type = std::make_shared<DB::DataTypeString>();
const auto * right_arg = &actions_dag.addColumn(std::move(type_name_col));
const auto * left_arg = node;
DB::CastDiagnostic diagnostic = {node->result_name, node->result_name};
ColumnWithTypeAndName left_column{nullptr, node->result_type, {}};
DB::ActionsDAG::NodeRawConstPtrs children = {left_arg, right_arg};
return &actions_dag.addFunction(
DB::createInternalCastOverloadResolver(cast_type, std::move(diagnostic)), std::move(children), result_name);
auto func_base_cast = createInternalCast(std::move(left_column), cast_to_type, cast_type, diagnostic);

return &actions_dag.addFunction(func_base_cast, std::move(children), result_name);
}

const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
Expand All @@ -489,7 +492,7 @@ const DB::ActionsDAG::Node * ActionsDAGUtil::convertNodeTypeIfNeeded(
if (node->result_type->equals(*dst_type))
return node;

return convertNodeType(actions_dag, node, dst_type->getName(), result_name, cast_type);
return convertNodeType(actions_dag, node, dst_type, result_name, cast_type);
}

String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline)
Expand Down
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ class ActionsDAGUtil
public:
static const DB::ActionsDAG::Node * convertNodeType(
DB::ActionsDAG & actions_dag,
const DB::ActionsDAG::Node * node,
const std::string & type_name,
const DB::ActionsDAG::Node * node_to_cast,
const DB::DataTypePtr & cast_to_type,
const std::string & result_name = "",
DB::CastType cast_type = DB::CastType::nonAccurate);

Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct LambdaLess
auto compare_res_col = lambda_->reduce();
DB::Field field;
compare_res_col.column->get(0, field);
return field.get<Int32>() < 0;
return field.safeGet<Int32>() < 0;
}
private:
ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ template <typename To>
Field convertNumericType(const Field & from)
{
if (from.getType() == Field::Types::UInt64)
return convertNumericTypeImpl<UInt64, To>(from.get<UInt64>());
return convertNumericTypeImpl<UInt64, To>(from.safeGet<UInt64>());
if (from.getType() == Field::Types::Int64)
return convertNumericTypeImpl<Int64, To>(from.get<Int64>());
return convertNumericTypeImpl<Int64, To>(from.safeGet<Int64>());
if (from.getType() == Field::Types::UInt128)
return convertNumericTypeImpl<UInt128, To>(from.get<UInt128>());
return convertNumericTypeImpl<UInt128, To>(from.safeGet<UInt128>());
if (from.getType() == Field::Types::Int128)
return convertNumericTypeImpl<Int128, To>(from.get<Int128>());
return convertNumericTypeImpl<Int128, To>(from.safeGet<Int128>());
if (from.getType() == Field::Types::UInt256)
return convertNumericTypeImpl<UInt256, To>(from.get<UInt256>());
return convertNumericTypeImpl<UInt256, To>(from.safeGet<UInt256>());
if (from.getType() == Field::Types::Int256)
return convertNumericTypeImpl<Int256, To>(from.get<Int256>());
return convertNumericTypeImpl<Int256, To>(from.safeGet<Int256>());

throw Exception(ErrorCodes::TYPE_MISMATCH, "Type mismatch. Expected: Integer. Got: {}", from.getType());
}
Expand All @@ -81,7 +81,7 @@ inline UInt32 extractArgument(const ColumnWithTypeAndName & named_column)
throw Exception(
ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow, precision/scale value must in UInt32", named_column.type->getName());
}
return static_cast<UInt32>(to.get<UInt32>());
return static_cast<UInt32>(to.safeGet<UInt32>());
}

}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionFloor.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class SparkFunctionFloor : public DB::FunctionFloor
if (scale_field.getType() != Field::Types::UInt64 && scale_field.getType() != Field::Types::Int64)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type");

Int64 scale64 = scale_field.get<Int64>();
Int64 scale64 = scale_field.safeGet<Int64>();
if (scale64 > std::numeric_limits<Scale>::max() || scale64 < std::numeric_limits<Scale>::min())
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large");

Expand Down
40 changes: 20 additions & 20 deletions cpp-ch/local-engine/Functions/SparkFunctionHashingExtended.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,42 +101,42 @@ class SparkFunctionAnyHash : public IFunction
if (which.isNothing())
return seed;
else if (which.isUInt8())
return applyNumber<UInt8>(field.get<UInt8>(), seed);
return applyNumber<UInt8>(field.safeGet<UInt8>(), seed);
else if (which.isUInt16())
return applyNumber<UInt16>(field.get<UInt16>(), seed);
return applyNumber<UInt16>(field.safeGet<UInt16>(), seed);
else if (which.isUInt32())
return applyNumber<UInt32>(field.get<UInt32>(), seed);
return applyNumber<UInt32>(field.safeGet<UInt32>(), seed);
else if (which.isUInt64())
return applyNumber<UInt64>(field.get<UInt64>(), seed);
return applyNumber<UInt64>(field.safeGet<UInt64>(), seed);
else if (which.isInt8())
return applyNumber<Int8>(field.get<Int8>(), seed);
return applyNumber<Int8>(field.safeGet<Int8>(), seed);
else if (which.isInt16())
return applyNumber<Int16>(field.get<Int16>(), seed);
return applyNumber<Int16>(field.safeGet<Int16>(), seed);
else if (which.isInt32())
return applyNumber<Int32>(field.get<Int32>(), seed);
return applyNumber<Int32>(field.safeGet<Int32>(), seed);
else if (which.isInt64())
return applyNumber<Int64>(field.get<Int64>(), seed);
return applyNumber<Int64>(field.safeGet<Int64>(), seed);
else if (which.isFloat32())
return applyNumber<Float32>(field.get<Float32>(), seed);
return applyNumber<Float32>(field.safeGet<Float32>(), seed);
else if (which.isFloat64())
return applyNumber<Float64>(field.get<Float64>(), seed);
return applyNumber<Float64>(field.safeGet<Float64>(), seed);
else if (which.isDate())
return applyNumber<UInt16>(field.get<UInt16>(), seed);
return applyNumber<UInt16>(field.safeGet<UInt16>(), seed);
else if (which.isDate32())
return applyNumber<Int32>(field.get<Int32>(), seed);
return applyNumber<Int32>(field.safeGet<Int32>(), seed);
else if (which.isDateTime())
return applyNumber<UInt32>(field.get<UInt32>(), seed);
return applyNumber<UInt32>(field.safeGet<UInt32>(), seed);
else if (which.isDateTime64())
return applyDecimal<DateTime64>(field.get<DateTime64>(), seed);
return applyDecimal<DateTime64>(field.safeGet<DateTime64>(), seed);
else if (which.isDecimal32())
return applyDecimal<Decimal32>(field.get<Decimal32>(), seed);
return applyDecimal<Decimal32>(field.safeGet<Decimal32>(), seed);
else if (which.isDecimal64())
return applyDecimal<Decimal64>(field.get<Decimal64>(), seed);
return applyDecimal<Decimal64>(field.safeGet<Decimal64>(), seed);
else if (which.isDecimal128())
return applyDecimal<Decimal128>(field.get<Decimal128>(), seed);
return applyDecimal<Decimal128>(field.safeGet<Decimal128>(), seed);
else if (which.isStringOrFixedString())
{
const String & str = field.get<String>();
const String & str = field.safeGet<String>();
return applyUnsafeBytes(str.data(), str.size(), seed);
}
else if (which.isTuple())
Expand All @@ -145,7 +145,7 @@ class SparkFunctionAnyHash : public IFunction
assert(tuple_type);

const auto & elements = tuple_type->getElements();
const Tuple & tuple = field.get<Tuple>();
const Tuple & tuple = field.safeGet<Tuple>();
assert(tuple.size() == elements.size());

for (size_t i = 0; i < elements.size(); ++i)
Expand All @@ -160,7 +160,7 @@ class SparkFunctionAnyHash : public IFunction
assert(array_type);

const auto & nested_type = array_type->getNestedType();
const Array & array = field.get<Array>();
const Array & array = field.safeGet<Array>();
for (size_t i=0; i < array.size(); ++i)
{
seed = applyGeneric(array[i], seed, nested_type);
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionMakeDecimal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ namespace
else
return false;
}
result = static_cast<ToNativeType>(convert_to.get<ToNativeType>());
result = static_cast<ToNativeType>(convert_to.safeGet<ToNativeType>());

ToNativeType pow10 = intExp10OfSize<ToNativeType>(precision_value);
if ((result < 0 && result <= -pow10) || result >= pow10)
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ class FunctionRoundingHalfUp : public IFunction
if (scale_field.getType() != Field::Types::UInt64 && scale_field.getType() != Field::Types::Int64)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type");

Int64 scale64 = scale_field.get<Int64>();
Int64 scale64 = scale_field.safeGet<Int64>();
if (scale64 > std::numeric_limits<Scale>::max() || scale64 < std::numeric_limits<Scale>::min())
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large");

Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class SparkFunctionConvertToDateTime : public IFunction

Field field;
named_column.column->get(0, field);
return static_cast<UInt32>(field.get<UInt32>());
return static_cast<UInt32>(field.safeGet<UInt32>());
}

DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Operator/ExpandTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void ExpandTransform::work()

if (kind == EXPAND_FIELD_KIND_SELECTION)
{
auto index = field.get<Int32>();
auto index = field.safeGet<Int32>();
const auto & input_column = input_columns[index];

DB::ColumnWithTypeAndName input_arg;
Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/AggregateFunctionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ const DB::ActionsDAG::Node * AggregateFunctionParser::convertNodeTypeIfNeeded(
if (need_convert_type)
{
func_node = ActionsDAGUtil::convertNodeType(
actions_dag, func_node, TypeParser::parseType(output_type)->getName(), func_node->result_name);
actions_dag, func_node, TypeParser::parseType(output_type), func_node->result_name);
actions_dag.addOrReplaceInOutputs(*func_node);
}

Expand Down
Loading

0 comments on commit 2ad689b

Please sign in to comment.