DeviceType.h

#pragma once

// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking.)
// If you modify me, keep me synchronized with that file.

#include <c10/macros/Macros.h>

#include <ostream>
#include <functional>
namespace c10 {

enum class DeviceType : int16_t {
  CPU = 0,
  CUDA = 1, // CUDA.
  MKLDNN = 2, // Reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  // Change the following number if you add more devices in the code.
  COMPILE_TIME_MAX_DEVICE_TYPES = 8,
  ONLY_FOR_TEST = 20901, // This device type is only for test.
};
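
// Usage sketch (not part of the header): COMPILE_TIME_MAX_DEVICE_TYPES is a
// convenient compile-time bound for fixed-size per-device-type tables. The
// array name below is illustrative only, and real code would need <array>:
//
//   std::array<int, static_cast<size_t>(
//       c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)> per_device_counters{};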

C10_API std::string DeviceTypeName(
    DeviceType d,
    bool lower_case = false);

C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);

} // namespace c10

namespace std {
template <> struct hash<c10::DeviceType> {
  std::size_t operator()(c10::DeviceType k) const {
    return std::hash<int>()(static_cast<int>(k));
  }
};
} // namespace std

// TODO: Remove me when we get a global c10 namespace using in at
namespace at {
using c10::DeviceType;
using c10::DeviceTypeName;
} // namespace at
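
Usage sketch (not part of the header; the include path assumes the header's
upstream PyTorch location of c10/core/DeviceType.h, and the map name
tensors_per_device is illustrative only):

  #include <c10/core/DeviceType.h>

  #include <iostream>
  #include <unordered_map>

  int main() {
    c10::DeviceType d = c10::DeviceType::CUDA;

    // Streamed through the operator<< declared in the header.
    std::cout << d << "\n";

    // DeviceTypeName returns a std::string; lower_case selects the casing.
    std::cout << c10::DeviceTypeName(d, /*lower_case=*/true) << "\n";

    // The std::hash specialization makes DeviceType usable as a key in
    // unordered containers.
    std::unordered_map<c10::DeviceType, int> tensors_per_device;
    ++tensors_per_device[d];

    return 0;
  }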