forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
interned_strings.cpp
125 lines (107 loc) · 3.44 KB
/
interned_strings.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <ATen/core/interned_strings.h>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include <ATen/core/interned_strings_class.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
namespace c10 {
// Prefix prepended to a namespace name to form its domain string
// (see Symbol::domainString / Symbol::fromDomainAndUnqualString).
// Held in a function-local static, so the returned reference stays valid
// for the rest of the program.
const std::string& domain_prefix() {
  static const std::string prefix{"org.pytorch."};
  return prefix;
}
// Thread-safe entry point for interning a qualified string.
// _symbol() does no locking of its own, so mutex_ is held for the whole call.
Symbol InternedStrings::symbol(const std::string& s) {
  const std::lock_guard<std::mutex> lock(mutex_);
  return _symbol(s);
}
// Returns the {qualified, unqualified} name pair for `sym`, i.e.
// {"<ns>::<name>", "<name>"}.
//
// Builtin symbols (those enumerated by FORALL_NS_SYMBOLS) are answered from
// string literals generated by the macro below, without taking mutex_.
// Every other symbol is resolved through customString(), which locks.
std::pair<const char*, const char*> InternedStrings::string(Symbol sym) {
// Builtin Symbols are also in the maps, but
// we can bypass the need to acquire a lock
// to read the map for Builtins because we already
// know their string value
switch (sym) {
// Expands to one `case` per builtin symbol; the returned pointers are
// stringified literals, so they remain valid for the whole program.
#define DEFINE_CASE(ns, s) \
case static_cast<unique_t>(ns::s): \
return {#ns "::" #s, #s};
FORALL_NS_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
default:
// Not a builtin: fall back to the locked lookup of dynamically interned names.
return customString(sym);
}
}
// Returns the namespace Symbol that `sym` belongs to.
//
// For builtin symbols this is the compile-time constant namespaces::<ns>
// produced by the macro below, with no locking. For any other symbol the
// namespace recorded by _symbol() is read from sym_to_info_ under mutex_;
// `.at()` throws if `sym` was never interned.
Symbol InternedStrings::ns(Symbol sym) {
switch (sym) {
// One `case` per builtin, mapping it directly to its namespace symbol.
#define DEFINE_CASE(ns, s) \
case static_cast<unique_t>(ns::s): \
return namespaces::ns;
FORALL_NS_SYMBOLS(DEFINE_CASE)
#undef DEFINE_CASE
default: {
// Dynamically interned symbol: the table may be mutated concurrently.
std::lock_guard<std::mutex> guard(mutex_);
return sym_to_info_.at(sym).ns;
}
}
}
// Returns the Symbol for the qualified string `s`, interning it if needed.
//
// Caller must hold mutex_: this function performs no locking itself and
// recurses to intern the namespace of `s`.
//
// `s` must have the form "<namespace>::<name>". The namespace is interned
// (recursively) as "namespaces::<namespace>". A newly created Symbol's value
// is its index in sym_to_info_.
//
// Throws std::runtime_error if `s` contains no "::".
Symbol InternedStrings::_symbol(const std::string& s) {
  auto it = string_to_sym_.find(s);
  if (it != string_to_sym_.end()) {
    return it->second;
  }

  const auto pos = s.find("::");
  if (pos == std::string::npos) {
    std::stringstream ss;
    ss << "all symbols must have a namespace, <namespace>::<string>, but found: " << s;
    throw std::runtime_error(ss.str());
  }

  // Intern the namespace first. This may append to sym_to_info_, so the new
  // symbol's index has to be taken afterwards.
  Symbol ns_sym = _symbol("namespaces::" + s.substr(0, pos));

  // Append the SymbolInfo *before* publishing the string -> Symbol mapping.
  // If either step throws (e.g. std::bad_alloc), string_to_sym_ never ends
  // up holding an index that sym_to_info_ does not contain.
  Symbol sym(sym_to_info_.size());
  sym_to_info_.push_back({ns_sym, s, s.substr(pos + strlen("::"))});
  try {
    string_to_sym_.emplace(s, sym);
  } catch (...) {
    // Roll back so the two tables stay in sync.
    sym_to_info_.pop_back();
    throw;
  }
  return sym;
}
// Locked lookup of the {qualified, unqualified} names of a non-builtin symbol.
// `.at()` throws if `sym` was never interned.
//
// NOTE(review): the returned pointers outlive the lock. If sym_to_info_ is a
// std::vector (declared in the header, not visible here), a later push_back in
// _symbol() can reallocate and move the SymbolInfo strings. Pointers into
// short (SSO) strings would then dangle. A container with stable element
// addresses (e.g. std::deque) would avoid this. Confirm against the header.
std::pair<const char*, const char*> InternedStrings::customString(Symbol sym) {
  const std::lock_guard<std::mutex> lock(mutex_);
  const SymbolInfo& info = sym_to_info_.at(sym);
  const char* qualified = info.qual_name.c_str();
  const char* unqualified = info.unqual_name.c_str();
  return {qualified, unqualified};
}
// The process-wide intern table. It is a function-local static, so it is
// constructed on first use, and that initialization is thread-safe.
static InternedStrings & globalStrings() {
  static InternedStrings instance;
  return instance;
}
// Interns the qualified string `s` ("<namespace>::<name>") in the global
// table and returns its Symbol. Throws std::runtime_error if `s` lacks "::".
Symbol Symbol::fromQualString(const std::string & s) {
  InternedStrings& table = globalStrings();
  return table.symbol(s);
}
// Returns the name of this symbol without its namespace qualifier.
const char * Symbol::toUnqualString() const {
  const auto names = globalStrings().string(*this);
  return names.second;
}
// Returns the fully qualified name of this symbol ("<namespace>::<name>").
const char * Symbol::toQualString() const {
  const auto names = globalStrings().string(*this);
  return names.first;
}
// Returns a human-readable name for this symbol; currently the qualified name.
const char * Symbol::toDisplayString() const {
  // TODO: Make this actually return something that's "user friendly".
  // The result must be a const char* with global lifetime so it can be used
  // in printf-style assert statements. That rules out building a string on
  // the fly, so the interned qualified name is reused for now.
  const char* display = toQualString();
  return display;
}
// Returns the namespace Symbol this symbol belongs to.
Symbol Symbol::ns() const {
  InternedStrings& table = globalStrings();
  return table.ns(*this);
}
// Returns domain_prefix() followed by the unqualified name of this symbol's
// namespace (e.g. "org.pytorch.<namespace>").
std::string Symbol::domainString() const {
  std::string domain = domain_prefix();
  domain += ns().toUnqualString();
  return domain;
}
// Builds a Symbol from a domain string `d` (which must start with
// domain_prefix()) and an unqualified name `s`. The qualified string
// "<d without prefix>::<s>" is interned via fromQualString.
// Throws std::runtime_error if `d` does not start with the prefix.
Symbol Symbol::fromDomainAndUnqualString(const std::string & d, const std::string & s) {
  const std::string& prefix = domain_prefix();
  const bool has_prefix = d.compare(0, prefix.size(), prefix) == 0;
  if (!has_prefix) {
    std::ostringstream msg;
    msg << "Symbol: domain string is expected to be prefixed with '"
        << prefix << "', e.g. 'org.pytorch.aten'";
    throw std::runtime_error(msg.str());
  }
  std::string qual = d.substr(prefix.size());
  qual += "::";
  qual += s;
  return fromQualString(qual);
}
} // namespace c10