forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsaved_variable.cpp
108 lines (95 loc) · 3.89 KB
/
saved_variable.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <ATen/Tensor.h>
#include <cstdint>
#include <list>
#include <memory>
#include <sstream>
namespace torch { namespace autograd {
// Snapshots `variable` so it can be rebuilt later by unpack().
//
// variable:        the Variable to save; if it is undefined nothing is
//                  recorded and the constructor body is a no-op.
// is_output:       when true, no strong reference to the variable's grad_fn
//                  is stored.
// is_inplace_view: recorded in is_inplace_view_; for a non-leaf output it
//                  additionally causes the grad_fn to be kept as a weak
//                  reference.
//
// NOTE(review): for a non-leaf with is_inplace_view set but is_output clear,
// grad_fn_ is stored here while unpack() only consults weak_grad_fn_ when
// is_inplace_view_ is set -- confirm callers never pass that combination.
SavedVariable::SavedVariable(const Variable& variable, bool is_output, bool is_inplace_view) {
  if (!variable.defined()) {
    return;
  }

  was_default_constructed_ = false;
  output_nr_ = variable.output_nr();
  requires_grad_ = variable.requires_grad();

  const bool leaf = variable.is_leaf();
  has_grad_fn_ = !leaf;
  is_inplace_view_ = is_inplace_view;

  // Per the original author's note, the members below are shared_ptr copies
  // (slightly more expensive), so they are assigned here, after the defined()
  // check, rather than in an initializer list.
  data_ = variable.tensor_data();

  if (leaf) {
    // Leaves have no grad_fn; keep their gradient accumulator instead.
    grad_accumulator_ = variable.grad_accumulator();
  } else if (is_output) {
    // Outputs never get a strong grad_fn reference; in-place views get a
    // weak one, other outputs get none and rely on unpack()'s argument.
    if (is_inplace_view) {
      weak_grad_fn_ = variable.grad_fn();
    }
  } else {
    grad_fn_ = variable.grad_fn();
  }

  // Remember the version at save time so unpack() can detect later
  // modifications.
  version_counter_ = variable.version_counter();
  saved_version_ = version_counter_.current_version();
}
// Rebuilds a Variable from the state captured at construction time.
//
// saved_for: used as the grad_fn when this entry was saved from a non-leaf
//            but no grad_fn of its own is available (saving it would have
//            created a circular reference, so the caller supplies it).
//
// Returns an undefined Variable when no data is held and this object never
// saved a defined variable.
//
// Throws std::runtime_error if
//   - the data is gone although a defined variable had been saved
//     (ERR_BACKWARD_TWICE),
//   - a grad_fn is required but neither stored nor passed in,
//   - the version counter no longer matches the version recorded on save.
// Throws std::logic_error if the result requires grad, has no grad_fn, and
// the saved grad accumulator has expired.
Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (!data_.defined()) {
    if (was_default_constructed_) {
      return Variable();
    }
    throw std::runtime_error(ERR_BACKWARD_TWICE);
  }

  auto fn = is_inplace_view_ ? weak_grad_fn_.lock() : grad_fn_;
  const bool needs_caller_fn = has_grad_fn_ && !fn;
  if (needs_caller_fn) {
    // If saving the grad_fn would create a circular reference, then it must
    // be passed in to the unpack function.
    if (!saved_for) {
      throw std::runtime_error("No grad_fn for non-leaf saved variable");
    }
    fn = std::move(saved_for);
  }

  // The saved data is only usable if nothing modified it since it was saved.
  if (version_counter_.current_version() != saved_version_) {
    std::stringstream report;
    report << "one of the variables needed for gradient computation has been "
              "modified by an inplace operation: [";
    report << data_.type().toString() << " " << data_.sizes() << "]";
    if (fn) {
      report << ", which is output " << output_nr_;
      report << " of " << fn->name() << ",";
    }
    report << " is at version " << version_counter_.current_version();
    report << "; expected version " << saved_version_ << " instead.";
    if (AnomalyMode::is_enabled()) {
      report << " Hint: the backtrace further above shows the operation "
                "that failed to compute its gradient. The variable in question "
                "was changed in there or anywhere later. Good luck!";
    } else {
      report << " Hint: enable anomaly detection to find the operation "
                "that failed to compute its gradient, with torch.autograd."
                "set_detect_anomaly(True).";
    }
    throw std::runtime_error(report.str());
  }

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var = fn
      ? make_variable(data_, Edge(std::move(fn), output_nr_))
      : make_variable(data_, requires_grad_);
  var.set_version_counter(saved_version_);

  // If a Variable is a leaf (no grad_fn saved) and it requires grad, then the
  // grad accumulator should have been saved. Even if the Variable is no
  // longer alive, the accumulator should be kept alive by the references in
  // the graph.
  const bool grad_leaf = requires_grad_ && !var.grad_fn();
  if (grad_leaf && grad_accumulator_.expired()) {
    throw std::logic_error("No grad accumulator for a saved leaf!");
  }
  var.set_grad_accumulator(grad_accumulator_);
  return var;
}
// Message thrown by SavedVariable::unpack() when the saved data is no longer
// defined although a defined variable had been saved.
const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time, "
    "but the buffers have already been freed. "
    "Specify retain_graph=True when calling backward the first time.";
}} // namespace torch::autograd