// Copyright 2021 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <vector>

#include "libspu/core/value.h"

namespace spu {
class SPUContext;
}

namespace spu::kernel::hal {

/// the broadcast function
// @param in, the input
// @param to_shape, the target shape
// @param in_dims, which dimensions of to_shape the input's dimensions map onto
Value broadcast_to(SPUContext* ctx, const Value& in, const Shape& to_shape,
                   const Axes& in_dims = {});
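//
// Usage sketch (illustrative, not part of the original header; names such as
// `ctx` and `v` are placeholders, and Shape/Axes are assumed to be
// brace-constructible, as the defaulted arguments suggest):
//   // Broadcast a [3]-shaped value `v` to shape [2, 3]; the input's single
//   // dimension maps onto output dimension 1.
//   Value m = broadcast_to(ctx, v, Shape{2, 3}, Axes{1});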

/// the reshape function
// @param in, the input
// @param to_shape, the target shape
Value reshape(SPUContext* ctx, const Value& in, const Shape& to_shape);
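//
// Usage sketch (illustrative, not part of the original header):
//   // Reinterpret a [2, 3]-shaped value as [3, 2]; to_shape must have the
//   // same number of elements as the input.
//   Value r = reshape(ctx, m, Shape{3, 2});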

/// the slice function
// @param input, the input
// @param start_indices, the start indices
// @param end_indices, the end indices
// @param strides, the strides
Value slice(SPUContext* ctx, const Value& input, const Index& start_indices,
            const Index& end_indices, const Strides& strides = {});
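//
// Usage sketch (illustrative, not part of the original header; assumes
// XLA-style semantics where end_indices are exclusive, and that Index and
// Strides are brace-constructible):
//   // From a [4, 6]-shaped value `x`, take every row and every second
//   // column, yielding a [4, 3]-shaped result.
//   Value s = slice(ctx, x, Index{0, 0}, Index{4, 6}, Strides{1, 2});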

/// This is a special slice for a single element at indices
// @returns an array with empty shape (scalar)
Value slice_scalar_at(SPUContext* ctx, const Value& input,
                      const Index& indices);
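//
// Usage sketch (illustrative, not part of the original header):
//   // Extract the element at position (1, 2) as a scalar-shaped value.
//   Value elem = slice_scalar_at(ctx, x, Index{1, 2});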

// update a block of `in` with `update`; start_indices is the position in `in`
Value update_slice(SPUContext* ctx, const Value& in, const Value& update,
                   const Index& start_indices);
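//
// Usage sketch (illustrative, not part of the original header):
//   // Write a [2, 2]-shaped `patch` into `x` with its top-left corner at
//   // position (1, 3); the result keeps the shape of `x`.
//   Value y = update_slice(ctx, x, patch, Index{1, 3});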

/// the transpose function
// @param in, the input
// @param permutation, the permutation of the input's dimensions
Value transpose(SPUContext* ctx, const Value& in, const Axes& permutation = {});
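//
// Usage sketch (illustrative, not part of the original header):
//   // Swap the two dimensions of a [4, 6]-shaped value, yielding [6, 4].
//   Value t = transpose(ctx, x, Axes{1, 0});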

/// the reverse function
// @param in, the input
// @param dimensions, the dimensions to reverse
Value reverse(SPUContext* ctx, const Value& in, const Axes& dimensions);
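//
// Usage sketch (illustrative, not part of the original header):
//   // Reverse a [4, 6]-shaped value along its last dimension.
//   Value rv = reverse(ctx, x, Axes{1});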

/// Expand a scalar into to_shape.
/// Compared with broadcast, expand actually reallocates and assigns memory.
Value expand(SPUContext* ctx, const Value& in, const Shape& to_shape);
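//
// Usage sketch (illustrative, not part of the original header):
//   // Materialize a scalar into a [2, 2]-shaped value; unlike broadcast_to,
//   // every output element gets its own storage.
//   Value e = expand(ctx, scalar, Shape{2, 2});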

/// the pad function
// @param in, the input
// @param padding_value, the value used to fill the added padding
// @param edge_padding_low, the amount of padding added at the
//        low-end (next to index 0) of each dimension
// @param edge_padding_high, the amount of padding added at the high-end
//        (next to the highest index) of each dimension
// @param interior_padding, the amount of padding added between any two
//        elements in each dimension
Value pad(SPUContext* ctx, const Value& in, const Value& padding_value,
          const Sizes& edge_padding_low, const Sizes& edge_padding_high,
          const Sizes& interior_padding);
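//
// Usage sketch (illustrative, not part of the original header; assumes Sizes
// is brace-constructible):
//   // Given a scalar `zero` as the padding value, pad a [3]-shaped `v` with
//   // one leading element, two trailing elements, and one element between
//   // each pair of neighbours: the output shape is [8] (1 + 3 + 2 interior + 2).
//   Value p = pad(ctx, v, zero, Sizes{1}, Sizes{2}, Sizes{1});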

/// the concatenate function
// @param values, the values to concatenate
// @param axis, the axis along which to concatenate
Value concatenate(SPUContext* ctx, const std::vector<Value>& values,
                  int64_t axis);
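//
// Usage sketch (illustrative, not part of the original header):
//   // Stack two [2, 3]-shaped values along axis 0, yielding a [4, 3] result.
//   Value c = concatenate(ctx, {a, b}, 0);
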
} // namespace spu::kernel::hal