Skip to content

Commit

Permalink
[GLUTEN-6989][CH] Use bitmap256 to trim
Browse files Browse the repository at this point in the history
  • Loading branch information
lwz9103 committed Aug 27, 2024
1 parent 87e36b4 commit 8ea9f54
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra
sql("create table trim(trim_col String, src_col String) using parquet")
sql("""
|insert into trim values
| ('aba', 'a'),('bba', 'b'),('abcdef', 'abcd'),
| ('bAa', 'a'),('bba', 'b'),('abcdef', 'abcd'),
| (null, '123'),('123', null), ('', 'aaa'), ('bbb', '')
|""".stripMargin)

val sql0 = "select rtrim('aba', 'a') from trim order by src_col"
val sql1 = "select rtrim(trim_col, src_col) from trim order by src_col"
val sql2 = "select rtrim(trim_col, 'NSSS') from trim order by src_col"
val sql2 = "select rtrim(trim_col, 'cCBbAa') from trim order by src_col"
val sql3 = "select rtrim(trim_col, '') from trim order by src_col"
val sql4 = "select rtrim('', 'AAA') from trim order by src_col"
val sql5 = "select rtrim('', src_col) from trim order by src_col"
val sql6 = "select rtrim('ttt', src_col) from trim order by src_col"
val sql6 = "select rtrim('ab', src_col) from trim order by src_col"

runQueryAndCompare(sql0) { _ => }
runQueryAndCompare(sql1) { _ => }
Expand Down
45 changes: 27 additions & 18 deletions cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ namespace
}

ColumnPtr
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & /*result_type*/, size_t input_rows_count) const override
{
const ColumnString * src_col = checkAndGetColumn<ColumnString>(arguments[0].column.get());
const ColumnConst * src_const_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
Expand All @@ -118,39 +118,34 @@ namespace
if (trim_const_col)
trim_const_str = trim_const_col->getValue<String>();
if (trim_const_col && trim_const_str.empty()) {
return arguments[0].column->cloneResized(input_rows_count);
return arguments[0].column;
}

if (src_const_col && trim_const_col)
{
const char * dst;
size_t dst_size;
std::unordered_set<char> trim_set(trim_const_str.begin(), trim_const_str.end());
trim(src_const_str.c_str(), src_const_str.size(), dst, dst_size, trim_set);
return result_type->createColumnConst(input_rows_count, String(dst, dst_size));
}
// If both arguments are constants, it will be simplified to a constant. Skipped here.

auto res_col = ColumnString::create();
ColumnString::Chars & res_data = res_col->getChars();
ColumnString::Offsets & res_offsets = res_col->getOffsets();
res_offsets.resize_exact(input_rows_count);

// Source column is constant and trim column is not constant
if (src_const_col)
{
res_data.reserve_exact(src_const_str.size() * input_rows_count);
for (size_t row = 0; row < input_rows_count; ++row)
{
StringRef trim_str_ref = trim_col->getDataAt(row);
std::unordered_set<char> trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size);
std::unique_ptr<std::bitset<256>> trim_set = buildTrimSet(trim_str_ref.toString());
executeRow(src_const_str.c_str(), src_const_str.size(), res_data, res_offsets, row, trim_set);
}
return std::move(res_col);
}

// Source column is not constant and trim column is constant
if (trim_const_col)
{
res_data.reserve_exact(src_col->getChars().size());
std::unordered_set<char> trim_set(trim_const_str.begin(), trim_const_str.end());
std::unique_ptr<std::bitset<256>> trim_set = buildTrimSet(trim_const_str);
for (size_t row = 0; row < input_rows_count; ++row)
{
StringRef src_str_ref = src_col->getDataAt(row);
Expand All @@ -165,7 +160,7 @@ namespace
{
StringRef src_str_ref = src_col->getDataAt(row);
StringRef trim_str_ref = trim_col->getDataAt(row);
std::unordered_set<char> trim_set(trim_str_ref.data, trim_str_ref.data + trim_str_ref.size);
std::unique_ptr<std::bitset<256>> trim_set = buildTrimSet(trim_str_ref.toString());
executeRow(src_str_ref.data, src_str_ref.size, res_data, res_offsets, row, trim_set);
}
return std::move(res_col);
Expand All @@ -178,7 +173,7 @@ namespace
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets,
size_t & row,
const std::unordered_set<char> & trim_set) const
const std::unique_ptr<std::bitset<256>> & trim_set) const
{
const char * dst;
size_t dst_size;
Expand All @@ -191,16 +186,30 @@ namespace
res_offsets[row] = res_offset;
}

void trim(const char * src, size_t src_size, const char *& dst, size_t & dst_size, const std::unordered_set<char> & trim_set) const
std::unique_ptr<std::bitset<256>> buildTrimSet(const String& trim_str) const
{
const char * src_end = src + src_size;
auto trim_set = std::make_unique<std::bitset<256>>();
for (unsigned char i : trim_str)
trim_set->set(i);
return trim_set;
}

void trim(const char * src, size_t src_size, const char *& dst, size_t & dst_size, const std::unique_ptr<std::bitset<256>> & trim_set) const
{
if (!trim_set || trim_set->none())
{
dst = src;
dst_size = src_size;
return;
}

const char * src_end = src + src_size;
if constexpr (TrimMode::trim_left)
while (src < src_end && trim_set.contains(*src))
while (src < src_end && trim_set->test((unsigned char)*src))
++src;

if constexpr (TrimMode::trim_right)
while (src < src_end && trim_set.contains(*(src_end - 1)))
while (src < src_end && trim_set->test((unsigned char)*(src_end - 1)))
--src_end;

dst = const_cast<char *>(src);
Expand Down

0 comments on commit 8ea9f54

Please sign in to comment.