forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NamedTensorUtils.h
111 lines (86 loc) · 4.04 KB
/
NamedTensorUtils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#pragma once
#include <ATen/core/EnableNamedTensor.h>
#include <ATen/NamedTensor.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/DimVector.h>
#include <functional>
#ifdef BUILD_NAMEDTENSOR
namespace at {
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
inline bool has_names(TensorList tensors) {
return std::any_of(
tensors.begin(), tensors.end(), [](const Tensor& t) { return t.has_names(); });
}
// Converts dim to an positional index. Errors if `dim` cannot be used to
// refer to any dimension of tensor.
CAFFE2_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
CAFFE2_API std::vector<int64_t> dimnames_to_positions(const Tensor& tensor, DimnameList dims);
// Unifies two DimnameList to produce a third. This is useful for implementing
// the named inference rule for binary broadcasting operations like add.
//
// There are three main constraints:
// 1) Check matching: Names must match positionally from the right.
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
// the same index from the right in other.
// 3) The output names are obtained by unifying the names individually from the right.
CAFFE2_API std::vector<Dimname>
unify_from_right(DimnameList names, DimnameList other, const char* action = "broadcast");
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
TORCH_CHECK(
false,
op_name, ": You passed a dimname (string) to this op in place of a dimension "
"index but it does not yet support this behavior. Please pass a dimension "
"index to work around this.");
}
namespace namedinference {
// Names get propagated via the following rules:
// 1) If result does not have names, then `names` get propagated.
// 2) If result has names, then `names` must be equal to result.names
void propagate_names(Tensor& result, optional<DimnameList> names);
void propagate_names(Tensor& result, std::vector<Dimname>&& names, bool validate_names);
CAFFE2_API void propagate_names(Tensor& result, optional<std::vector<Dimname>>&& maybe_names, bool validate_names);
void propagate_names(TensorImpl* result, optional<DimnameList> names);
void propagate_names(TensorImpl* result, std::vector<Dimname>&& names, bool validate_names);
void propagate_names(TensorImpl* result, optional<std::vector<Dimname>>&& maybe_names, bool validate_names);
// Propagates all names from src to result.
CAFFE2_API void propagate_names(Tensor& result, const Tensor& src);
void propagate_names(TensorImpl* result, /*const */TensorImpl* src);
// Propagates all names except for those at the excluded_idxs.
void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs);
// Used for reduction ops that have a `keepdim` arg.
void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim);
// result = m1 @ m2 + bias
void propagate_names_for_addmm(
TensorImpl* result,
/*const*/TensorImpl* m1,
/*const*/TensorImpl* m2,
/*const*/TensorImpl* bias);
void propagate_names_for_addmv(
TensorImpl* result,
TensorImpl* mat,
TensorImpl* vec,
TensorImpl* bias);
void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);
void propagate_names_for_expand(Tensor& result, const Tensor& self);
optional<std::vector<Dimname>> compute_cat_outnames(TensorList tensors);
optional<std::vector<Dimname>> compute_broadcast_outnames(
const Tensor& self,
const Tensor& other);
optional<std::vector<Dimname>> broadcast_to_outnames(
const Tensor& tensor,
const Tensor& reference_tensor,
const char* op_name);
optional<std::vector<Dimname>> compute_baddbmm_outnames(
TensorImpl* result,
TensorImpl* self,
TensorImpl* other,
TensorImpl* bias);
optional<std::vector<Dimname>> compute_matmul_outnames(const Tensor& self, const Tensor& other);
optional<std::vector<Dimname>> compute_bmm_outnames(
Tensor& result,
const Tensor& self,
const Tensor& other);
optional<std::vector<Dimname>> compute_squeeze_outnames(const Tensor& tensor);
} // namespace namedinference
} // namespace at
#endif