forked from apache/incubator-gluten
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MyUDF.cpp
96 lines (77 loc) · 3.22 KB
/
MyUDF.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <velox/expression/VectorFunction.h>
#include <iostream>
#include <type_traits>
#include "udf/Udf.h"
namespace {
using namespace facebook::velox;
/// Vector function that adds a fixed integer constant to every selected row of
/// its single argument, writing a result vector of the same scalar kind.
///
/// @tparam Kind Velox scalar type kind of both the argument and the result
///   (this file instantiates it for INTEGER and BIGINT).
template <TypeKind Kind>
class PlusConstantFunction : public exec::VectorFunction {
 public:
  /// @param addition Constant added to each selected input value.
  explicit PlusConstantFunction(int32_t addition) : addition_(addition) {}

  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& /* outputType */,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    using nativeType = typename TypeTraits<Kind>::NativeType;
    VELOX_CHECK_EQ(args.size(), 1);
    auto& arg = args[0];
    // The argument may be flat or constant.
    VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding());
    BaseVector::ensureWritable(rows, createScalarType<Kind>(), context.pool(), result);
    auto* flatResult = result->asFlatVector<nativeType>();
    auto* rawResult = flatResult->mutableRawValues();
    // Every selected row is assigned a value below, so none of them is null.
    // NOTE(review): null inputs are presumably filtered out of `rows` by
    // Velox's default null handling before apply() runs — confirm.
    flatResult->clearNulls(rows);
    if (arg->isConstantEncoding()) {
      // A constant vector carries one value for all rows; compute the sum once.
      const auto value = arg->as<ConstantVector<nativeType>>()->valueAt(0);
      const auto sum = addWrapping(value, addition_);
      rows.applyToSelected([&](auto row) { rawResult[row] = sum; });
    } else {
      const auto* rawInput = arg->as<FlatVector<nativeType>>()->rawValues();
      rows.applyToSelected([&](auto row) { rawResult[row] = addWrapping(rawInput[row], addition_); });
    }
  }

 private:
  /// Returns `lhs + rhs` converted to T.
  ///
  /// For integral T the sum is computed in the corresponding unsigned type, so
  /// overflow wraps around (two's complement) instead of being undefined
  /// behavior as plain signed `+` would be. Other types use ordinary addition.
  template <typename T>
  static T addWrapping(T lhs, int32_t rhs) {
    if constexpr (std::is_integral_v<T> && !std::is_same_v<T, bool>) {
      using Unsigned = std::make_unsigned_t<T>;
      return static_cast<T>(static_cast<Unsigned>(lhs) + static_cast<Unsigned>(static_cast<T>(rhs)));
    } else {
      return static_cast<T>(lhs + rhs);
    }
  }

  /// Constant added to every input value.
  const int32_t addition_;
};
/// Signatures accepted by "myudf1": one overload taking a single `integer`
/// argument and returning `integer`.
static std::vector<std::shared_ptr<exec::FunctionSignature>> integerSignatures() {
  auto signature = exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build();
  return {signature};
}
/// Signatures accepted by "myudf2": one overload taking a single `bigint`
/// argument and returning `bigint`.
static std::vector<std::shared_ptr<exec::FunctionSignature>> bigintSignatures() {
  auto signature = exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build();
  return {signature};
}
} // namespace
// Number of UDFs exported by this library; must match the length of myUdf.
constexpr int kNumMyUdf = 2;
// One entry per exported UDF: its name and a type name (presumably the result
// type, matching the signatures registered for it — confirm against
// gluten::UdfEntry).
gluten::UdfEntry myUdf[kNumMyUdf] = {
    {"myudf1", "integer"},
    {"myudf2", "bigint"},
};
// Exported hook: reports how many UDFs this library provides (the length of
// the myUdf table).
DEFINE_GET_NUM_UDF {
  constexpr auto numUdfs = kNumMyUdf;
  return numUdfs;
}
// Exported hook: copies every entry of myUdf, in order, into the
// caller-provided udfEntries array.
// NOTE(review): the caller is assumed to size udfEntries using the count from
// DEFINE_GET_NUM_UDF — confirm.
DEFINE_GET_UDF_ENTRIES {
  int next = 0;
  for (const auto& entry : myUdf) {
    udfEntries[next++] = entry;
  }
}
// Exported hook: registers both UDFs with Velox. Each one adds the same
// constant to its argument; myudf1 works on integer and myudf2 on bigint.
DEFINE_REGISTER_UDF {
  namespace velox = facebook::velox;
  constexpr int32_t kAddition = 5;
  velox::exec::registerVectorFunction(
      "myudf1",
      integerSignatures(),
      std::make_unique<PlusConstantFunction<velox::TypeKind::INTEGER>>(kAddition));
  velox::exec::registerVectorFunction(
      "myudf2",
      bigintSignatures(),
      std::make_unique<PlusConstantFunction<velox::TypeKind::BIGINT>>(kAddition));
  std::cout << "registered myudf1, myudf2" << std::endl;
}