forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Device.cpp
95 lines (90 loc) · 2.76 KB
/
Device.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <c10/Device.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <algorithm>
#include <array>
#include <exception>
#include <ostream>
#include <string>
#include <tuple>
#include <vector>
namespace c10 {
namespace {
/// Translates a device-type name into the matching `DeviceType` enumerator.
///
/// The name is compared exactly (case-sensitively, via `std::string`
/// equality) against a fixed table of seven known spellings: cpu, cuda,
/// mkldnn, opengl, opencl, ideep and hip.
///
/// @param device_string  The type portion of a device string, e.g. "cuda".
/// @return The `DeviceType` paired with `device_string` in the table.
/// Reports an error through `AT_ERROR` when the name is not in the table.
DeviceType parse_type(const std::string& device_string) {
  static const std::array<std::pair<std::string, DeviceType>, 7> types = {{
      {"cpu", DeviceType::CPU},
      {"cuda", DeviceType::CUDA},
      {"mkldnn", DeviceType::MKLDNN},
      {"opengl", DeviceType::OPENGL},
      {"opencl", DeviceType::OPENCL},
      {"ideep", DeviceType::IDEEP},
      {"hip", DeviceType::HIP},
  }};
  // Capture by reference: the predicate only reads the string and never
  // outlives this call, so copying it into the closure is wasted work.
  const auto device = std::find_if(
      types.begin(),
      types.end(),
      [&device_string](const std::pair<std::string, DeviceType>& p) {
        return p.first == device_string;
      });
  if (device != types.end()) {
    return device->second;
  }
  // No return follows: AT_ERROR is expected not to return control here.
  AT_ERROR(
      "Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, or hip device type at start of device string");
}
} // namespace
// `std::regex` is still in a very incomplete state in GCC 4.8.x,
// so we have to do our own parsing, like peasants.
// https://stackoverflow.com/questions/12530406/is-gcc-4-8-or-earlier-buggy-about-regular-expressions
//
// Replace with the following code once we shed our GCC skin:
//
// static const std::regex regex(
// "(cuda|cpu)|(cuda|cpu):([0-9]+)|([0-9]+)",
// std::regex_constants::basic);
// std::smatch match;
// const bool ok = std::regex_match(device_string, match, regex);
// AT_CHECK(ok, "Invalid device string: '", device_string, "'");
// if (match[1].matched) {
// type_ = parse_type_from_string(match[1].str());
// } else {
// if (match[2].matched) {
// type_ = parse_type_from_string(match[2].str());
// } else {
// type_ = Type::CUDA;
// }
// AT_ASSERT(match[3].matched);
// index_ = std::stoi(match[3].str());
// }
/// Constructs a `Device` from a string of the form "<type>" or
/// "<type>:<index>".
///
/// The constructor first delegates to `Device(Type::CPU)` and then overwrites
/// `type_` (and, when a ':' is present, `index_`) from the parsed string.
/// Everything before the first ':' is handed to `parse_type`; everything
/// after it is converted with `c10::stoi`.
///
/// @param device_string  Text to parse, e.g. "cpu" or "cuda:1".
///
/// Errors are reported through `AT_CHECK` / `AT_ERROR` when:
///   - `device_string` is empty,
///   - the part before the first ':' is empty,
///   - `parse_type` rejects the type name,
///   - `c10::stoi` throws a `std::exception` while converting the index.
Device::Device(const std::string& device_string) : Device(Type::CPU) {
  AT_CHECK(!device_string.empty(), "Device string must not be empty");

  // Keep the position in the string's own unsigned size type. Storing it in
  // an `int` narrows the value and makes the `npos` comparison below depend
  // on an implicit signed-to-unsigned conversion.
  const std::string::size_type separator = device_string.find(':');
  if (separator == std::string::npos) {
    // No index part: only the type is taken from the string.
    type_ = parse_type(device_string);
    return;
  }

  const std::string type_name = device_string.substr(0, separator);
  AT_CHECK(!type_name.empty(), "Device string must not be empty");
  type_ = parse_type(type_name);

  const std::string device_index = device_string.substr(separator + 1);
  try {
    index_ = c10::stoi(device_index);
  } catch (const std::exception&) {
    // Replace the conversion failure with a message that quotes both the
    // offending index text and the full device string.
    AT_ERROR(
        "Could not parse device index '",
        device_index,
        "' in device string '",
        device_string,
        "'");
  }
}
/// Writes `device` to `stream` as its type, followed by ":" and the index
/// when `device.has_index()` is true. Returns `stream` to allow chaining.
std::ostream& operator<<(std::ostream& stream, const Device& device) {
  stream << device.type();
  if (!device.has_index()) {
    // Index-less devices are printed as the bare type.
    return stream;
  }
  stream << ":" << device.index();
  return stream;
}
} // namespace c10