/// To use custom autograd operations, implement a Function subclass with
/// static forward and backward functions:
///
/// `forward` can take as many arguments as you want and should return either a
/// variable list or a Variable. Any direct Variable arguments will be
/// registered in the graph, but no vectors, sets, or other data structures
/// will be traversed. You can use c10::optional<Tensor> as one of the
/// arguments, and it will be registered as a variable in the graph if the
/// argument has a value. `forward` should take a pointer to
/// `torch::autograd::AutogradContext` as its first argument. Variables can be
/// saved in the `ctx` using `ctx->save_for_backward`
/// (see `torch::autograd::AutogradContext::save_for_backward`) and other data
/// can be saved in the `ctx->saved_data` map
/// (see `torch::autograd::AutogradContext::saved_data`)
/// in the form of `<std::string, at::IValue>` pairs.
///
/// `backward` should take a pointer to `torch::autograd::AutogradContext`
/// and a variable list containing as many Variables as there were outputs from
/// `forward`. It should return as many Variables as there were inputs, each of
/// them containing the gradient w.r.t. its corresponding input. Variables
/// saved in `forward` can be accessed with `ctx->get_saved_variables` (see
/// `torch::autograd::AutogradContext::get_saved_variables`) and other saved
/// data can be accessed from `ctx->saved_data`.
///
/// For example:
/// ```
/// class MyFunction : public Function<MyFunction> {
///   public:
///   static variable_list forward(AutogradContext *ctx, int n, Variable var) {
///     // Save data for backward in context
///     ctx->saved_data["n"] = n;
///     var.mul_(2);
///     // Mark var as modified by inplace operation
///     ctx->mark_dirty({var});
///     return {var};
///   }
///
///   static variable_list backward(AutogradContext *ctx,
///                                 variable_list grad_output) {
///     // Use data saved in forward
///     auto n = ctx->saved_data["n"].toInt();
///     // One entry per forward input: an undefined Variable for the
///     // non-Variable input `n`, then the gradient w.r.t. `var`.
///     return {Variable(), grad_output[0] * n};
///   }
/// };
/// ```
///
/// To use `MyFunction`:
/// ```
/// Variable x;
/// auto y = MyFunction::apply(6, x);
/// // Example backward call
/// y[0].sum().backward();
/// ```
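The example above only uses `ctx->saved_data`. As a complementary illustration, here is a hypothetical Function (names invented, not from the PyTorch sources) that uses `save_for_backward`/`get_saved_variables` as described in the prose above:

```cpp
// Hypothetical example: y = x * weight, so dy/dx = weight, dy/dweight = x.
struct MulFunction : public torch::autograd::Function<MulFunction> {
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor x,
      torch::Tensor weight) {
    // Both inputs are needed to compute the gradients, so save them.
    ctx->save_for_backward({x, weight});
    return {x * weight};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    auto saved = ctx->get_saved_variables();
    auto x = saved[0];
    auto weight = saved[1];
    // One gradient per forward input, in declaration order.
    return {grad_output[0] * weight, grad_output[0] * x};
  }
};

// auto y = MulFunction::apply(x, w); // y is a variable_list
// y[0].sum().backward();
```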
Function
// torch/csrc/autograd/function_hook.h
using Variable = at::Tensor;
using variable_list = std::vector<Variable>;

// torch/csrc/autograd/custom_function.h
using optional_variable_list = std::vector<c10::optional<Variable>>;
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;

// Get the return type of the forward function of the custom Function class X
template <typename X, typename... Args>
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...));
template <class T>
struct TORCH_API Function {
  // We need to use a different template parameter than T here because T will
  // inherit from Function, and when Function<T> is instantiated, T::forward
  // is not declared yet.
  // The enable_if check is to ensure that the user doesn't explicitly provide
  // the parameter X.
  template <typename X = T, typename... Args>
  static auto apply(Args&&... args)
      -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>>;
};
apply(Args&&... args)
template <class T>
template <typename X, typename... Args>
auto Function<T>::apply(Args&&... args)
    -> std::enable_if_t<std::is_same<X, T>::value, forward_t<X, Args...>> {
  std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list input_vars;

  const size_t num_inputs = sizeof...(Args);
  input_vars.reserve(num_inputs);
  node->is_variable_input_.reserve(num_inputs);
  // TODO Add tracing here
  extract_vars(node->is_variable_input_, input_vars, args...);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool is_executable =
      GradMode::is_enabled() && any_variable_requires_grad(input_vars);
  auto next_edges =
      (is_executable ? collect_next_edges(input_vars) : edge_list());
  node->set_ctx_grad_fn(node);
  node->set_next_edges(std::move(next_edges));
  node->clear_input_metadata();

  node->input_info_.reserve(input_vars.size());
  for (auto& var : input_vars) {
    node->input_info_.emplace_back(var);
  }

  using forward_return_t = forward_t<X, Args...>;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  forward_return_t outputs;
  {
    AutoGradMode grad_mode(false);
    outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
  }

  _jvp_fn_t jvp_fn = [](variable_list inputs,
                        variable_list gI) -> variable_list {
    TORCH_CHECK(
        false,
        "jvp is not implemented for the c++ API of custom Function yet.",
        "Please open a feature request on Github if you need this.");
  };

  auto wrapped_outputs = _wrap_outputs(
      input_vars,
      node->ctx_.get_non_differentiable(),
      node->ctx_.get_and_bump_dirty(),
      to_optional(outputs),
      is_executable ? node : nullptr,
      jvp_fn);

  node->output_info_.reserve(wrapped_outputs.size());
  for (auto& output : wrapped_outputs) {
    if (is_executable && output.has_value()) {
      node->output_info_.emplace_back(output.value());
    } else if (is_executable) {
      node->output_info_.emplace_back();
    }
  }

  if (is_executable) {
    node->save_variables_to_ctx();
  }

  // wrapped_outputs will be a variable_list, so convert it to the correct
  // return type. Only Variable and variable_list are accepted as return types.
  return to_output_type<forward_return_t>(wrapped_outputs);
}
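As the final comment notes, `forward` may also return a single Variable rather than a `variable_list`, in which case `apply` deduces a single-Tensor return type. A minimal hypothetical sketch (`Square` is an invented name):

```cpp
struct Square : public torch::autograd::Function<Square> {
  // Returning a single Variable: apply() then also returns one Tensor.
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx, torch::Tensor x) {
    ctx->save_for_backward({x});
    return x * x;
  }

  // backward still returns a variable_list: one gradient per forward input.
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    auto x = ctx->get_saved_variables()[0];
    return {2 * x * grad_output[0]};
  }
};

// Usage: y is a Tensor, not a variable_list.
// auto y = Square::apply(torch::ones({3}, torch::requires_grad()));
// y.sum().backward();
```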
FunctionPreHook
// torch/csrc/autograd/function_hook.h
struct TORCH_API FunctionPreHook {
  virtual ~FunctionPreHook();
  virtual variable_list operator()(const variable_list& grads) = 0;
};
FunctionPostHook
// torch/csrc/autograd/function_hook.h
struct TORCH_API FunctionPostHook {
  virtual ~FunctionPostHook();
  virtual variable_list operator()(
      const variable_list& outputs /* grad_inputs */,
      const variable_list& inputs /* grad_outputs */) = 0;
};
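These hooks run during backward, before and after a Node's apply. A speculative sketch (not from the sources above) of a concrete pre-hook that scales incoming gradients:

```cpp
struct ScaleGradPreHook : torch::autograd::FunctionPreHook {
  double scale;
  explicit ScaleGradPreHook(double s) : scale(s) {}

  torch::autograd::variable_list operator()(
      const torch::autograd::variable_list& grads) override {
    torch::autograd::variable_list scaled;
    scaled.reserve(grads.size());
    for (const auto& g : grads) {
      // Leave undefined gradients untouched; scale the defined ones.
      scaled.push_back(g.defined() ? g * scale : g);
    }
    return scaled;
  }
};

// Assuming Node exposes an add_pre_hook(std::unique_ptr<FunctionPreHook>&&)
// registration method (see torch/csrc/autograd/function.h), attaching the
// hook to a tensor's grad_fn might look like:
//   y.grad_fn()->add_pre_hook(std::make_unique<ScaleGradPreHook>(0.5));
```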
AutogradContext
/// Context to save information during `forward` that can be accessed in
/// `backward` in custom autograd operations (see `torch::autograd::Function`
/// for details).
struct TORCH_API AutogradContext {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AutogradContext() : materialize_grads_(true) {}
  AutogradContext(const AutogradContext& other) = delete;
  AutogradContext& operator=(const AutogradContext& other) = delete;

  /// Can be used to save non-variable data for `backward`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ska::flat_hash_map<std::string, at::IValue> saved_data;

  /// Saves the list of variables for a future call to `backward`. This
  /// should be called at most once from inside of `forward`.
  void save_for_backward(variable_list to_save);
  /// Marks variables in the list as modified in an in-place operation. This
  /// should be called at most once from inside of `forward` and all arguments
  /// should be inputs.
  void mark_dirty(const variable_list& inputs);
  /// Marks outputs in the list as not requiring gradients. This should be
  /// called at most once from inside of `forward` and all arguments should be
  /// outputs.
  void mark_non_differentiable(const variable_list& outputs);
  // Sets whether undefined output grad tensors should be expanded to tensors
  // full of zeros before calling backward function. Default value is true.
  void set_materialize_grads(bool value);

  /// Get the list of variables that were saved in `forward` using
  /// `save_for_backward()`. Before returning them to the user, a check is
  /// made to ensure that they were not modified by any in-place operations.
  variable_list get_saved_variables() const;
  const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const;
  const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;

 private:
  std::unordered_set<at::TensorImpl*> non_differentiable_;
  std::unordered_set<at::TensorImpl*> dirty_inputs_;
  std::vector<torch::autograd::SavedVariable> saved_variables_;
  variable_list to_save_;
  bool materialize_grads_;

  // The CppNode in the autograd graph that owns this AutogradContext. We need
  // a weak_ptr to avoid a refcycle. Since grad_fn_ owns this AutogradContext,
  // it will always be alive when we want to use it.
  std::weak_ptr<Node> grad_fn_;
  bool has_freed_buffers_;

  void save_variables();

  template <class T>
  friend struct CppNode;
};
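To ground this API, here is a hypothetical sketch (names invented) combining `save_for_backward`, `mark_non_differentiable`, and `set_materialize_grads`: forward returns a differentiable result plus a non-differentiable mask, and opts out of zero-materialization of undefined gradients.

```cpp
struct ClampPositive : public torch::autograd::Function<ClampPositive> {
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx, torch::Tensor x) {
    auto mask = (x > 0).to(x.scalar_type());
    ctx->save_for_backward({mask});
    // The mask is returned to the caller but carries no gradient.
    ctx->mark_non_differentiable({mask});
    // backward() will now see undefined grads instead of zero-filled ones.
    ctx->set_materialize_grads(false);
    return {x * mask, mask};
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_output) {
    auto mask = ctx->get_saved_variables()[0];
    // grad_output[1] (for the mask) stays undefined; with
    // set_materialize_grads(false), grad_output[0] may be undefined too.
    auto g = grad_output[0];
    return {g.defined() ? g * mask : g};
  }
};
```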
SavedVariable
/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
// torch/csrc/autograd/saved_variable.h
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API SavedVariable {
 public:
  SavedVariable() = default;
  SavedVariable(
      const Variable& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(
      const c10::optional<Variable>& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(SavedVariable&&) = default;
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;

  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);

  void reset_data();

 private:
  // This field contains either:
  // 1. the variable to save
  // 2. or its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case and its metadata is also saved separately.
  // In that case, the grad_fn must be passed in to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace operation
  // on a view, see below).
  // The field saved_original_ below reflects the two cases: its value is true
  // in the first case and false in the second case.
  // The value data_.defined() can be false in three cases:
  // 1. SavedVariable was constructed without a Tensor (the value to save is
  //    None), in which case was_default_constructed_ will be kept at true
  // 2. The saved variable has been released by calling
  //    SavedVariable::reset_data(), typically during the backward pass
  // 3. Hooks have been registered. In that case, hooks_ will be defined
  //    instead.
  // Note that the value of saved_original_ only reflects what happened during
  // the construction of the SavedVariable. If saved_original_ is true, we
  // saved the original tensor in data_, but if the user registers hooks, we
  // will no longer have it (despite saved_original_ still being true).
  at::Tensor data_;

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note
  // [ Using ForwardGrad ]
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable with
  // is_inplace_on_view = true.
  // In that case, the grad_fn passed in to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;
  c10::VariableVersion version_counter_;

  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  // Hooks are a pair of functions pack_hook/unpack_hook that provide
  // fine-grained control over how the SavedVariable should save its data.
  // pack_hook is called upon registration, while unpack_hook is called when
  // unpacking.
  std::unique_ptr<SavedVariableHooks> hooks_;
  // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if
  // hooks are defined. They are set before pack_hook is called and used after
  // unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(
      std::unique_ptr<SavedVariableHooks>&& hooks,
      const Variable& data);
};
ForwardGrad
// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner
// design. But this shared_ptr must be uniquely associated with the object that
// stores it (as of writing, either AutogradMeta or SavedVariable). This object
// is called the "owning object" in the discussions below. This owning object
// must call `ForwardGrad::clear()` when it is destroyed to ensure that the
// ForwardGrad is properly de-allocated.
// torch/csrc/autograd/forward_grad.h
struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
  ForwardGrad() = default;

  // This function must only be called when AutogradMeta or SavedVariable is
  // being destructed as it ensures that:
  // - The only (potential) other references to this ForwardGrad are the
  //   different levels it is registered to
  // - No other thread will try to call `set_value` or `value` ever from now on
  // - Any of the ForwardADLevels that this ForwardGrad is registered with
  //   might call `reset` at any point during this function
  void clear() {
    c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;

    {
      std::lock_guard<std::mutex> lock(mutex_);
      for (auto& c : content_) {
        levels_idx.push_back(c.first);
      }
    }

    for (auto l_idx : levels_idx) {
      // Use the "try" version here as another thread might have deleted this
      // level before we got here.
      // This is an owning reference as we want to keep the level alive
      // until we successfully unregister ourselves.
      auto level = ForwardADLevel::try_get_by_idx(l_idx);
      if (level) {
        level->erase(shared_from_this());
      }
    }
  }

  void set_value(const at::Tensor& value, uint64_t level) {
    // Owning reference to ensure the forward_level is not destroyed
    // while we are updating our internal state
    auto forward_level = ForwardADLevel::get_by_idx(level);
    forward_level->insert(shared_from_this());

    std::lock_guard<std::mutex> lock(mutex_);
    content_.insert({level, value});
  }

  // This function removes the tangent for a given level from this ForwardGrad.
  // Use the update_level flag to disable notifying the level about this reset.
  // This flag is most notably used by the ForwardADLevel destructor.
  void reset(uint64_t level, bool update_level = true) {
    if (update_level) {
      ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
    }

    std::unique_lock<std::mutex> lock(mutex_);
    const auto& it = content_.find(level);
    TORCH_INTERNAL_ASSERT(
        it != content_.end(), "Resetting a non-existent level.");
    // Keep the Tensor alive until we have released the lock.
    // This is needed as we can be in a case where this function is called by
    // the ForwardADLevel destructor.
    auto t = (*it).second;
    content_.erase(level);
    lock.unlock();
  }

  const at::Tensor& value(uint64_t level) const;

  bool contains(uint64_t level) {
    std::lock_guard<std::mutex> lock(mutex_);
    return content_.count(level) > 0;
  }

  bool empty() const {
    return content_.empty();
  }

  static const at::Tensor& undef_grad();

 private:
  // TODO(albanD): replace this with a SmallVector
  std::unordered_map<uint64_t, at::Tensor> content_;
  mutable std::mutex mutex_;
};
ForwardADLevel
// torch/csrc/autograd/forward_grad.h
struct TORCH_API ForwardADLevel {
  ForwardADLevel(uint64_t idx) : idx_(idx) {}
  ~ForwardADLevel();

  static uint64_t get_next_idx();
  static void release_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);

  void erase(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.erase(grad);
  }

  void insert(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.insert(grad);
  }

 private:
  std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
  std::mutex mutex_;
  uint64_t idx_;
};
SavedVariableHooks
// torch/csrc/autograd/saved_variable_hooks.h
struct TORCH_API SavedVariableHooks {
  virtual void call_pack_hook(const at::Tensor& tensor) = 0;
  virtual at::Tensor call_unpack_hook() = 0;
  virtual ~SavedVariableHooks() = default;
};
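A speculative sketch of a SavedVariableHooks implementation that offloads the packed tensor to CPU and restores it on unpack, similar in spirit to Python's torch.autograd.graph.saved_tensors_hooks. It would be attached to a SavedVariable through the `register_hooks()` method shown earlier; the class name and fields here are invented.

```cpp
struct OffloadToCpuHooks : torch::autograd::SavedVariableHooks {
  at::Tensor packed_;
  c10::Device original_device_{c10::DeviceType::CPU};

  void call_pack_hook(const at::Tensor& tensor) override {
    // Remember where the tensor lived, then keep only a CPU copy.
    original_device_ = tensor.device();
    packed_ = tensor.cpu();
  }

  at::Tensor call_unpack_hook() override {
    // Move the saved data back to its original device for backward.
    return packed_.to(original_device_);
  }
};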
VariableVersion
// NOTE [ Version Counter Sharing ]
//
// Every Tensor has a version counter. Version counters are incremented whenever
// the data or size of a tensor changes through in-place Variable operations.
// Version counters are used to detect modifications to saved variables which
// would result in incorrect gradient calculations. Version counters may be
// shared between Variables:
//
// 1. A view shares the version counter of the base Variable,
// 2. `x.detach()` shares the version counter of `x`,
// 3. Unpacked saved variables share the version counter of the source.
//
// Version counters are not shared in these scenarios:
//
// 1. When we replace a `Variable`'s underlying `Tensor` by calling
//    `set_data(...)`,
// 2. `x.data` does not share the version counter of `x`. (See discussion at
// https://github.com/pytorch/pytorch/issues/5396)
//
// Question: Why do we put the version counter in TensorImpl instead of
// AutogradMeta?
//
// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta
// when its `requires_grad_` is false, but when we use this tensor in the
// forward pass of a function that requires saving this tensor for backward, we
// need to keep track of this tensor's version to make sure it's always valid in
// the autograd graph.
//
// To achieve this goal, we put the version counter in TensorImpl instead of
// AutogradMeta, and have it always be available. This allows us to have the
// optimization of not carrying AutogradMeta when a tensor doesn't require
// gradient.
//
// A hypothetical alternative way to achieve this goal is to initialize
// AutogradMeta and create the version counter for the non-requires-grad tensor
// only when it's saved for backward. However, since saving a tensor for
// backward happens in the forward pass, and our invariant is that forward pass
// needs to be thread-safe, lazy-initializing AutogradMeta when saving a tensor
// can introduce race conditions when we are running the forward pass in
// multi-thread scenarios, thus making the forward pass not thread-safe anymore,
// which breaks the invariant.
// c10/core/TensorImpl.h
struct C10_API VariableVersion {
 private:
  struct VersionCounter : intrusive_ptr_target {
    VersionCounter(uint32_t version) : version_(version) {}
    std::atomic<uint32_t> version_;
  };
  c10::intrusive_ptr<VersionCounter> version_counter_;

 public:
  // Note [Disabled VariableVersion]
  // The VariableVersion struct has an intrusive_ptr pointing to a
  // VersionCounter struct with an atomic variable. Thus
  // `VariableVersion(/*version=*/0)` is not as cheap as we expected. In some
  // cases constructing a VariableVersion with version 0 is not necessary, so
  // we add a cheap constructor which doesn't allocate the intrusive_ptr.
  // Example use cases are:
  // - Inference tensors don't track version counters, so they'll just always
  //   have a disabled VariableVersion.
  // - In the SavedVariable class we override version_counter_ inside its
  //   constructor so that we can use the cheap constructor there.
  enum Disabled { DISABLED };

  // It's okay to return true even for an inference tensor which
  // doesn't have the version counter enabled.
  // We want to be permissive here since in many cases (e.g. make_variable)
  // we can std::move a TensorImpl if there are no other uses, which saves us
  // an additional TensorImpl allocation.
  bool unique() const {
    return version_counter_ ? 1 == version_counter_.use_count() : true;
  }

  // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
  // leaves it in a persistently undefined state. See
  // https://cplusplus.github.io/LWG/issue2334.
  VariableVersion(uint32_t version)
      : version_counter_(c10::make_intrusive<VersionCounter>(version)) {}
  VariableVersion(Disabled = DISABLED) {}

  bool enabled() const {
    return version_counter_;
  }

  // Note [Inplace update inference tensor]
  // 1. Inplace update to an inference tensor is forbidden in normal mode.
  //    For example:
  //      inference_tensor.copy_(normal_tensor_requires_grad)
  //    This inplace makes inference_tensor have requires_grad=True and
  //    have a grad_fn. This is bad because views of `inference_tensor`
  //    created in InferenceMode won't be able to know the grad_fn since
  //    their ViewMeta were not recorded. To match the NoGradMode behavior
  //    that "inplace update to a view created in NoGradMode raises an error",
  //    we just ban inplace update to inference tensors since we can't tell
  //    if an inference tensor is a view created in InferenceMode.
  //
  //    Note that views of a normal tensor created in InferenceMode have proper
  //    ViewMeta so that they're aware of the grad_fn correctly.
  //
  // 2. Inplace update to an inference tensor in inference mode doesn't bump
  //    the version counter.
  //    * It either doesn't call bump() by skipping the ADInplaceOrView kernel,
  //      - e.g. inference_tensor.add_(1)
  //    * or bump() is a no-op for inference tensors.
  //      - e.g. inference_tensor.add_(normal_tensor)
  void bump() {
    // TODO: Replace the link to the documentation once it's available.
    TORCH_CHECK(
        version_counter_ || InferenceMode::is_enabled(),
        "Inplace update to inference tensor outside InferenceMode is not "
        "allowed. You can make a clone to get a normal tensor before doing "
        "inplace update. See https://github.com/pytorch/rfcs/pull/17 for "
        "more details.");
    if (version_counter_) {
      ++version_counter_->version_;
    }
  }

  // Inference tensors don't have a version counter, so it shouldn't be
  // accessed.
  uint32_t current_version() const {
    TORCH_CHECK(
        version_counter_, "Inference tensors do not track version counter.");
    return version_counter_->version_;
  }
};
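A hypothetical snippet (not from the sources above) illustrating the note on version counters: multiplication saves `a` at its current version, the in-place `add_` bumps the counter, and unpacking the now-stale SavedVariable during backward would raise an error.

```cpp
auto x = torch::ones({2, 2}, torch::requires_grad());
auto a = x.clone(); // non-leaf, so in-place updates on it are allowed
auto b = a * a;     // saves `a` for backward at its current version
a.add_(1.0);        // in-place op bumps a's version counter
// b.sum().backward(); // would throw: a variable needed for gradient
// computation has been modified by an inplace operation
```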
VariableInfo
struct TORCH_API VariableInfo {
  explicit VariableInfo();
  explicit VariableInfo(const Variable& var);

  Variable zeros(at::OptionalDeviceGuard& device_guard) const;

  at::Layout layout = at::Layout::Strided;
  at::Device device = at::kCPU;
  at::ScalarType scalar_type = at::kFloat;
  std::vector<int64_t> size;
  bool requires_grad;
  bool is_empty;
};
CppNode
// CppNode<T> is the Node in the autograd graph that represents the
// user-defined backward function for Function<T>. Calls to CppNode::apply are
// forwarded to T::backward().
template <class T>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct CppNode : public Node {
  variable_list apply(variable_list&& inputs) override;
  AutogradContext ctx_;
  std::vector<bool> is_variable_input_;
  std::vector<VariableInfo> input_info_;
  std::vector<VariableInfo> output_info_;

  void release_variables() override;

  void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
  void save_variables_to_ctx();
};
apply
// The logic here is the same as PyNode::apply, so changes to it should be done
// in both places.
template <class T>
variable_list CppNode<T>::apply(variable_list&& inputs) {
  at::OptionalDeviceGuard _device_guard;

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int num_inputs = inputs.size();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list backward_inputs;
  backward_inputs.reserve(num_inputs);
  for (const auto i : c10::irange(num_inputs)) {
    if (inputs[i].defined() || !ctx_.materialize_grads_) {
      backward_inputs.emplace_back(inputs[i]);
    } else {
      backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
    }
  }

  // Acquire lock here to protect thread safety on the custom C++ Autograd
  // Node. This is needed for the custom Autograd Node since we don't know if
  // the user-defined Node will write to the shared data during backward.
  // See Note [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);

  auto outputs = T::backward(&ctx_, backward_inputs);

  const auto num_forward_inputs =
      static_cast<int64_t>(is_variable_input_.size());
  auto num_outputs = static_cast<int64_t>(outputs.size());
  // Returning too many results is ok, but only as long as they're all
  // undefined. Truncate the result vector in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_undef = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_undef &= (!outputs[i].defined());
    }
    if (all_undef) {
      outputs.resize(num_forward_inputs);
      num_outputs = num_forward_inputs;
    }
  }

  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += c10::to_string(num_forward_inputs) + ", got ";
    msg += c10::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list results;
  results.reserve(num_outputs);
  for (const auto i : c10::irange(num_outputs)) {
    if (!is_variable_input_[i]) {
      if (outputs[i].defined()) {
        std::string msg("function ");
        msg += name() + " returned a defined gradient at position ";
        msg += c10::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    results.emplace_back(outputs[i]);
  }
  return results;
}
release_variables
template <class T>
void CppNode<T>::release_variables() {
  // lock to ensure thread safety, see [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);
  ctx_.saved_variables_.clear();
  ctx_.has_freed_buffers_ = true;
}
save_variables_to_ctx
template <class T>
void CppNode<T>::save_variables_to_ctx() {
  ctx_.save_variables();
}
set_ctx_grad_fn
template <class T>
void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node>& node) {
  ctx_.grad_fn_ = node;
}