forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNamedTensor.h
63 lines (48 loc) · 2.13 KB
/
NamedTensor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#pragma once
#ifdef BUILD_NAMEDTENSOR
#include <ATen/Dimname.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/utils/memory.h>
namespace at {
// XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
// Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
// so we have a couple of workarounds.
//
// In the long term, we'll move Dimname to c10 and everything in this file
// can be refactored out. The main blocker for that is that "c10::Symbol"
// actually exists outside of c10 and needs to be moved in.
// TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
// XXX: Ideally we would just put optional<vector<Dimname>> into TensorImpl.
// ATen-side concrete implementation of c10::NamedTensorMetaInterface.
// Owns a vector of Dimname entries (presumably one per tensor dimension --
// confirm with callers). Both setters assert that the number of entries
// never changes after construction.
struct CAFFE2_API NamedTensorMeta : public c10::NamedTensorMetaInterface {
  // Builds `num_names` entries, each initialized to Dimname::wildcard().
  explicit NamedTensorMeta(int64_t num_names)
    : names_(num_names, Dimname::wildcard()) {}

  // Copies every name out of the given view.
  explicit NamedTensorMeta(DimnameList names)
    : names_(names.begin(), names.end()) {}

  // Adopts an already-built vector of names without copying.
  explicit NamedTensorMeta(std::vector<Dimname>&& names)
    : names_(std::move(names)) {}

  // Returns an independent deep copy, typed as the c10 interface so that
  // TensorImpl (which only knows the interface) can hold it.
  std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
    const DimnameList current_names(names_);
    return torch::make_unique<NamedTensorMeta>(current_names);
  }

  // Defined out of line; not visible from this header.
  bool has_names() const;

  // Non-owning view of the stored names. The view dangles if the vector
  // overload of set_names_ later replaces the underlying storage.
  DimnameList names() const {
    return DimnameList(names_);
  }

  // Overwrites the stored names element by element, keeping the existing
  // storage. `new_names` must have exactly as many entries as are stored now.
  void set_names_(DimnameList new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    auto dest = names_.begin();
    for (const Dimname& name : new_names) {
      *dest = name;
      ++dest;
    }
  }

  // Replaces the stored vector wholesale by moving `new_names` in.
  // `new_names` must have exactly as many entries as are stored now.
  void set_names_(std::vector<Dimname>&& new_names) {
    TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
    names_ = std::move(new_names);
  }

 private:
  std::vector<Dimname> names_;
};
namespace impl {

// Free-function helpers that operate on a TensorImpl's names; handy when
// working with names from TH code. Only declarations appear here -- the
// definitions live in a separate translation unit.
// XXX: Ideally these would exist as methods on TensorImpl.

// Presumably stores `names` on `impl`, with nullopt likely meaning "drop the
// names" -- confirm against the definition.
CAFFE2_API void internal_set_names_inplace(
    TensorImpl* impl,
    optional<DimnameList> names);

// Overload that takes ownership of a vector of names. `validate_names`
// presumably toggles checking of the names before they are stored -- confirm.
CAFFE2_API void internal_set_names_inplace(
    TensorImpl* impl,
    std::vector<Dimname>&& names,
    bool validate_names);

// Presumably yields the names attached to `impl`, or nullopt when it has
// none -- confirm against the definition.
CAFFE2_API optional<DimnameList> internal_get_names(
    TensorImpl* impl);

// Presumably reports whether `impl` currently carries names -- confirm.
CAFFE2_API bool internal_has_names(
    TensorImpl* impl);

} // namespace impl
} // namespace at
#endif