PyTorch AutoGrad Function

/// To use custom autograd operations, implement a Function subclass with
/// static forward and backward functions:
///
/// `forward` can take as many arguments as you want and should return either a
/// variable list or a Variable. Use of any direct Variable arguments will be
/// registered in the graph but no vectors/sets or any other data structures
/// will be traversed. You can use c10::optional<Tensor> as one of the arguments
/// and it will be registered as a variable in the graph if the argument has a
/// value. It should take a pointer to `torch::autograd::AutogradContext` as the
/// first argument. Variables can be saved in the `ctx` using
/// `ctx->save_for_backward`
/// (see `torch::autograd::AutogradContext::save_for_backward`) and other data
/// can be saved in the `ctx->saved_data` map
/// (see `torch::autograd::AutogradContext::saved_data`)
/// in the form of `<std::string, at::IValue>` pairs.
///
/// `backward` should take a pointer to `torch::autograd::AutogradContext`
/// and a variable list containing as many Variables as there were outputs from
/// `forward` as arguments. It should return as many Variables as there were
/// inputs with each of them containing the gradient w.r.t. its corresponding
/// input. Variables saved in `forward` can be accessed with
/// `ctx->get_saved_variables` (see
/// `torch::autograd::AutogradContext::get_saved_variables`) and other saved
/// data can be accessed from `ctx->saved_data`.
///
/// For example:
/// ```
/// class MyFunction : public Function<MyFunction> {
///   public:
///   static variable_list forward(AutogradContext *ctx, int n, Variable var) {
///      // Save data for backward in context
///      ctx->saved_data["n"] = n;
///      var.mul_(2);
///      // Mark var as modified by inplace operation
///      ctx->mark_dirty({var});
///      return {var};
///   }
///
///   static variable_list backward(AutogradContext *ctx, variable_list
///   grad_output) {
///      // Use data saved in forward
///      auto n = ctx->saved_data["n"].toInt();
///      return {grad_output[0]*n};
///   }
/// };
/// ```
///
/// To use `MyFunction`:
/// ```
/// Variable x;
/// auto y = MyFunction::apply(6, x);
/// // Example backward call
/// y[0].sum().backward();
/// ```
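The snippet above is quoted from the header documentation. Below is a self-contained variant you can compile against libtorch (a sketch: `ScaleByN` is an illustrative name, not a PyTorch class, and, following the gradient-count check in `CppNode::apply` shown later in this article, `backward` returns one entry per `forward` argument, with an undefined tensor for the non-tensor argument `n`):

```
#include <torch/torch.h>
#include <iostream>

using namespace torch::autograd;

// Hypothetical custom op: scales the input by an integer factor n.
struct ScaleByN : public Function<ScaleByN> {
  static variable_list forward(AutogradContext* ctx, int n, Variable var) {
    ctx->saved_data["n"] = n;  // stash non-tensor data as an IValue
    return {var * n};
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_output) {
    auto n = ctx->saved_data["n"].toInt();
    // One gradient per forward argument: undefined for `n`,
    // and d(n*var)/d(var) = n for the tensor input.
    return {Variable(), grad_output[0] * n};
  }
};

int main() {
  auto x = torch::ones({3}, torch::requires_grad());
  auto y = ScaleByN::apply(6, x);
  y[0].sum().backward();
  std::cout << x.grad() << std::endl;  // a tensor of 6s
}
```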

Function

  
// torch/csrc/autograd/function_hook.h
using Variable = at::Tensor;
using variable_list = std::vector<Variable>;

// torch/csrc/autograd/custom_function.h
using optional_variable_list = std::vector<c10::optional<Variable>>;
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;

// Get the return type of the forward function of the custom Function class X
template<typename X, typename... Args>
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...));
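The `forward_t` alias is a standard `decltype`/`declval` trick: it names the return type of `X::forward` called with `(AutogradContext*, Args...)` without ever invoking it. A minimal stand-alone sketch of the same idiom (the `Foo` class here is hypothetical):

```
#include <type_traits>
#include <utility>
#include <vector>

struct Foo {
  // Declaration only: the call below happens in an unevaluated context,
  // so no definition is needed.
  static std::vector<int> forward(void* /*ctx*/, int, double);
};

template <typename X, typename... Args>
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...));

static_assert(
    std::is_same<forward_t<Foo, int, double>, std::vector<int>>::value,
    "forward_t names Foo::forward's return type");

int main() {}
```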

  
template <class T>
struct TORCH_API Function {
  // We need to use a different template parameter than T here because T will
  // inherit from Function, and when Function<T> is instantiated, T::forward
  // is not declared yet.
  // The enable_if check is to ensure that the user doesn't explicitly provide
  // the parameter X.
  template<typename X=T, typename... Args>
  static auto apply(Args&&... args) -> std::enable_if_t<std::is_same<X,T>::value, forward_t<X,Args...>>;
};

apply(Args&&... args)

  
template<class T>
template<typename X, typename... Args>
auto Function<T>::apply(Args&&... args) -> std::enable_if_t<std::is_same<X,T>::value, forward_t<X,Args...>> {
  std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list input_vars;

  const size_t num_inputs = sizeof...(Args);
  input_vars.reserve(num_inputs);
  node->is_variable_input_.reserve(num_inputs);
  // TODO Add tracing here
  extract_vars(node->is_variable_input_, input_vars, args...);

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  bool is_executable = GradMode::is_enabled() && any_variable_requires_grad(input_vars);
  auto next_edges = (is_executable ? collect_next_edges(input_vars) : edge_list());
  node->set_ctx_grad_fn(node);
  node->set_next_edges(std::move(next_edges));
  node->clear_input_metadata();

  node->input_info_.reserve(input_vars.size());
  for (auto& var : input_vars) {
      node->input_info_.emplace_back(var);
  }

  using forward_return_t = forward_t<X, Args...>;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  forward_return_t outputs;
  {
    AutoGradMode grad_mode(false);
    outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
  }

  _jvp_fn_t jvp_fn = [](variable_list inputs, variable_list gI) -> variable_list {
    TORCH_CHECK(false, "jvp is not implemented for the c++ API of custom Function yet.",
                "Please open a feature request on Github if you need this.");
  };

  auto wrapped_outputs = _wrap_outputs(
    input_vars,
    node->ctx_.get_non_differentiable(),
    node->ctx_.get_and_bump_dirty(),
    to_optional(outputs),
    is_executable ? node : nullptr,
    jvp_fn);

  node->output_info_.reserve(wrapped_outputs.size());
  for (auto& output : wrapped_outputs) {
    if (is_executable && output.has_value()) {
      node->output_info_.emplace_back(output.value());
    } else if (is_executable) {
      node->output_info_.emplace_back();
    }
  }

  if (is_executable) {
    node->save_variables_to_ctx();
  }

  // wrapped_outputs will be a variable_list, so convert it to the correct
  // return type. Only Variable and variable_list are accepted as return types.
  return to_output_type<forward_return_t>(wrapped_outputs);
}
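`extract_vars`, called near the top of `apply`, is a small recursive overload set: for each argument it records in `is_variable_input_` whether the argument is a `Variable`, and appends the `Variable`s themselves to `input_vars`. A simplified sketch of that recursion (the real overloads in custom_function.h also handle cases like `c10::optional<Tensor>`):

```
#include <utility>
#include <vector>
#include <ATen/core/Tensor.h>

// Assumption: the aliases from function_hook.h shown earlier.
using Variable = at::Tensor;
using variable_list = std::vector<Variable>;

// Base case: no arguments left.
inline void extract_vars(std::vector<bool>& is_var, variable_list& list) {}

// A Variable argument: record it and collect it.
template <typename... Args>
inline void extract_vars(std::vector<bool>& is_var, variable_list& list,
                         Variable& var, Args&&... args) {
  is_var.push_back(true);
  list.emplace_back(var);
  extract_vars(is_var, list, std::forward<Args>(args)...);
}

// Any other argument: recorded as non-variable, not traversed.
template <typename T, typename... Args>
inline void extract_vars(std::vector<bool>& is_var, variable_list& list,
                         T&& cur, Args&&... args) {
  is_var.push_back(false);
  extract_vars(is_var, list, std::forward<Args>(args)...);
}
```

Overload resolution picks the `Variable&` overload for tensor arguments because it is more specialized than the generic `T&&` one under partial ordering.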

FunctionPreHook

  
// torch/csrc/autograd/function_hook.h
struct TORCH_API FunctionPreHook {
  virtual ~FunctionPreHook();
  virtual variable_list operator()(const variable_list& grads) = 0;
};

FunctionPostHook

  
// torch/csrc/autograd/function_hook.h
struct TORCH_API FunctionPostHook {
  virtual ~FunctionPostHook();
  virtual variable_list operator()(
    const variable_list& outputs /* grad_inputs */,
    const variable_list& inputs /* grad_outputs */) = 0;
};
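Both kinds of hooks are attached to a `Node`: pre-hooks can rewrite incoming gradients before the node's `apply` runs, and post-hooks observe or rewrite the results afterwards. A minimal sketch of a post-hook (`PrintGradHook` is hypothetical, and `Node::add_post_hook` / `Tensor::grad_fn()` are internal-ish APIs that may shift between releases):

```
#include <torch/torch.h>
#include <torch/csrc/autograd/function.h>
#include <iostream>

using torch::autograd::variable_list;

// Hypothetical post-hook that prints the norm of each gradient
// produced by the Node it is attached to.
struct PrintGradHook : torch::autograd::FunctionPostHook {
  variable_list operator()(
      const variable_list& outputs /* grad_inputs */,
      const variable_list& inputs /* grad_outputs */) override {
    for (const auto& g : outputs) {
      if (g.defined()) {
        std::cout << "grad norm: " << g.norm().item<double>() << "\n";
      }
    }
    return outputs;  // pass gradients through unchanged
  }
};

int main() {
  auto x = torch::randn({3}, torch::requires_grad());
  auto y = x * 2;
  // Attach the hook to the Node that produced y.
  y.grad_fn()->add_post_hook(std::make_unique<PrintGradHook>());
  y.sum().backward();
}
```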

AutogradContext

  
/// Context to save information during `forward` that can be accessed in `backward`
/// in custom autograd operations (see `torch::autograd::Function` for details).
struct TORCH_API AutogradContext {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  AutogradContext() : materialize_grads_(true) {}
  AutogradContext(const AutogradContext &other) = delete;
  AutogradContext& operator=(const AutogradContext& other) = delete;

  /// Can be used to save non-variable data for `backward`.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ska::flat_hash_map<std::string, at::IValue> saved_data;

  /// Saves the list of variables for a future call to `backward`. This
  /// should be called at most once from inside of `forward`.
  void save_for_backward(variable_list to_save);
  /// Marks variables in the list as modified in an in-place operation. This
  /// should be called at most once from inside of `forward` and all arguments
  /// should be inputs.
  void mark_dirty(const variable_list &inputs);
  /// Marks outputs in the list as not requiring gradients. This should be called
  /// at most once from inside of `forward` and all arguments should be outputs.
  void mark_non_differentiable(const variable_list &outputs);
  // Sets whether undefined output grad tensors should be expanded to tensors
  // full of zeros before calling backward function. Default value is true.
  void set_materialize_grads(bool value);

  /// Get the list of variables that were saved in `forward` using
  /// `save_for_backward()`. Before returning them to the user, a check is made to
  /// ensure that they were not modified by any in-place operations.
  variable_list get_saved_variables() const;
  const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const;
  const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;

private:
  std::unordered_set<at::TensorImpl*> non_differentiable_;
  std::unordered_set<at::TensorImpl*> dirty_inputs_;
  std::vector<torch::autograd::SavedVariable> saved_variables_;
  variable_list to_save_;
  bool materialize_grads_;

  // The CppNode in the autograd graph that owns this AutogradContext. We need a
  // weak_ptr to avoid a refcycle. Since grad_fn_ owns this AutogradContext, it
  // will always be alive when we want to use it.
  std::weak_ptr<Node> grad_fn_;
  bool has_freed_buffers_;

  void save_variables();

  template <class T> friend struct CppNode;
};
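To see `save_for_backward`/`get_saved_variables` in context, here is a hypothetical two-tensor-input function (`MyMul` is illustrative, not from the PyTorch sources). `get_saved_variables` is also where the version-counter check described later fires if a saved tensor was mutated in place:

```
#include <torch/torch.h>

using namespace torch::autograd;

// Hypothetical example: z = x * y, saving both inputs for the backward pass.
struct MyMul : public Function<MyMul> {
  static Variable forward(AutogradContext* ctx, Variable x, Variable y) {
    ctx->save_for_backward({x, y});
    return x * y;
  }
  static variable_list backward(AutogradContext* ctx, variable_list grad_out) {
    auto saved = ctx->get_saved_variables();  // checked against in-place edits
    auto x = saved[0], y = saved[1];
    return {grad_out[0] * y, grad_out[0] * x};  // d(xy)/dx = y, d(xy)/dy = x
  }
};

int main() {
  auto a = torch::rand({2}, torch::requires_grad());
  auto b = torch::rand({2}, torch::requires_grad());
  auto z = MyMul::apply(a, b);
  z.sum().backward();  // afterwards a.grad() == b and b.grad() == a
}
```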

SavedVariable

  
/// A snapshot of a variable at a certain version. A `SavedVariable` stores  
/// enough information to reconstruct a variable from a certain point in time.  

  
// torch/csrc/autograd/saved_variable.h
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_API SavedVariable {
 public:
  SavedVariable() = default;
  SavedVariable(const Variable& variable, bool is_output, bool is_inplace_on_view=false);
  SavedVariable(const c10::optional<Variable>& variable, bool is_output, bool is_inplace_on_view=false);
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(SavedVariable&&) = default;
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;

  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);

  void reset_data();

 private:
  // This field contains either:
  // 1. the variable to save
  // 2. or its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case and its metadata is also saved separately.
  // In that case, the grad_fn must be passed in to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace operation on
  // a view, see below).
  // The field saved_original_ below reflects the two cases: its value is true
  // in the first case and false in the second case.
  // The value data_.defined() can be false in three cases:
  // 1. SavedVariable was constructed without a Tensor (the value to save is None), in
  // that case was_default_constructed_ will be kept at true
  // 2. The saved variable has been released by calling SavedVariable::reset_data(), typically
  // during the backward pass
  // 3. Hooks have been registered. In that case, hooks_ will be defined instead.
  // Note that the value of saved_original_ only reflects what happened during the construction
  // of the SavedVariable. If saved_original_ is true, we saved the original tensor in data_,
  // but if the user registers hooks, we will no longer have it (despite saved_original_ still
  // being true)
  at::Tensor data_;

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note [ Using ForwardGrad ]
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable with
  // is_inplace_on_view = true.
  // In that case, the grad_fn passed in to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;
  c10::VariableVersion version_counter_;

  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;

  // Hooks are a pair of functions pack_hook/unpack_hook that provide fine-grained control
  // over how the SavedVariable should save its data.
  // pack_hook is called upon registration, while unpack_hook is called when unpacking.
  std::unique_ptr<SavedVariableHooks> hooks_;
  // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if hooks are defined.
  // They are set before pack_hook is called and used after unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(std::unique_ptr<SavedVariableHooks>&& hooks, const Variable& data);
};

ForwardGrad

  
// [ Using ForwardGrad ]
// ForwardGrad needs to be a shared_ptr to satisfy constraints of its inner design. But
// this shared_ptr must be uniquely associated with the object that stores it (as of
// writing, either AutogradMeta or SavedVariable). This object is called the "owning object"
// in the discussions below. This owning object must call `ForwardGrad::clear()` when it
// is destroyed to ensure that the ForwardGrad is properly de-allocated.

  
// torch/csrc/autograd/forward_grad.h
struct TORCH_API ForwardGrad : std::enable_shared_from_this<ForwardGrad> {
  ForwardGrad() = default;

  // This function must only be called when AutogradMeta or SavedVariable is
  // being destructed as it ensures that:
  //   - The only (potential) other references to this ForwardGrad are the
  //     different levels it is registered to
  //   - No other thread will try to call `set_value` or `value` ever from now on
  //   - Any of the ForwardADLevel that this ForwardGrad is registered with
  //     might call `reset` at any point during this function
  void clear() {
    c10::SmallVector<uint64_t, EXPECTED_MAX_LEVEL> levels_idx;

    {
      std::lock_guard<std::mutex> lock(mutex_);
      for (auto& c : content_) {
        levels_idx.push_back(c.first);
      }
    }

    for (auto l_idx : levels_idx) {
      // Use "try" version here as another thread might have deleted this
      // level before we got here
      // This is an owning reference as we want to keep the level alive
      // until we successfully unregister ourselves
      auto level = ForwardADLevel::try_get_by_idx(l_idx);
      if (level) {
        level->erase(shared_from_this());
      }
    }
  }

  void set_value(const at::Tensor& value, uint64_t level) {
      // Owning reference to ensure the forward_level is not destroyed
      // while we are updating our internal state
      auto forward_level = ForwardADLevel::get_by_idx(level);
      forward_level->insert(shared_from_this());

      std::lock_guard<std::mutex> lock(mutex_);
      content_.insert({level, value});
  }

  // This function removes the tangent for a given level from this ForwardGrad.
  // Use the update_level flag to disable notifying the level about this reset.
  // This flag is most notably used by the ForwardADLevel destructor.
  void reset(uint64_t level, bool update_level=true) {
      if (update_level) {
          ForwardADLevel::get_by_idx(level)->erase(shared_from_this());
      }

      std::unique_lock<std::mutex> lock(mutex_);
      const auto& it = content_.find(level);
      TORCH_INTERNAL_ASSERT(it != content_.end(), "Resetting a non-existent level.");
      // Keep the Tensor alive until we have released the lock
      // This is needed as we can be in a case where this function is called by
      // the ForwardADLevel destructor
      auto t = (*it).second;
      content_.erase(level);
      lock.unlock();
  }

  const at::Tensor& value(uint64_t level) const;

  bool contains(uint64_t level) {
      std::lock_guard<std::mutex> lock(mutex_);
      return content_.count(level) > 0;
  }

  bool empty() const {
      return content_.empty();
  }

  static const at::Tensor& undef_grad();

private:
    // TODO(albanD): replace this with a SmallVector
    std::unordered_map<uint64_t, at::Tensor> content_;
    mutable std::mutex mutex_;
};

ForwardADLevel

  
// torch/csrc/autograd/forward_grad.h
struct TORCH_API ForwardADLevel {
  ForwardADLevel(uint64_t idx) : idx_(idx) {}
  ~ForwardADLevel();

  static uint64_t get_next_idx();
  static void release_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> get_by_idx(uint64_t idx);
  static std::shared_ptr<ForwardADLevel> try_get_by_idx(uint64_t idx);

  void erase(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.erase(grad);
  }

  void insert(const std::shared_ptr<ForwardGrad>& grad) {
    std::lock_guard<std::mutex> lock(mutex_);
    grads_.insert(grad);
  }

private:
    std::unordered_set<std::shared_ptr<ForwardGrad>> grads_;
    std::mutex mutex_;
    uint64_t idx_;
};

SavedVariableHooks

  
// torch/csrc/autograd/saved_variable_hooks.h
struct TORCH_API SavedVariableHooks {
  virtual void call_pack_hook(const at::Tensor &tensor) = 0;
  virtual at::Tensor call_unpack_hook() = 0;
  virtual ~SavedVariableHooks() = default;
};
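For illustration, a hypothetical implementation of this interface that parks the saved tensor on CPU at pack time and moves it back at unpack time (a sketch: `ToCpuHooks` is not a PyTorch-provided class; such hooks are attached internally via `SavedVariable::register_hooks`, and the user-facing equivalent is Python's `torch.autograd.graph.saved_tensors_hooks`):

```
#include <ATen/core/Tensor.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

// Hypothetical pack/unpack hooks: keep the saved tensor on CPU between
// forward and backward, restoring its original device on unpack.
struct ToCpuHooks : torch::autograd::SavedVariableHooks {
  void call_pack_hook(const at::Tensor& tensor) override {
    device_ = tensor.device();
    packed_ = tensor.to(at::kCPU);
  }
  at::Tensor call_unpack_hook() override {
    return packed_.to(device_);
  }
 private:
  at::Tensor packed_;
  at::Device device_{at::kCPU};
};
```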

VariableVersion

  
// NOTE [ Version Counter Sharing ]  
//  
// Every Tensor has a version counter. Version counters are incremented whenever  
// the data or size of a tensor changes through in-place Variable operations.  
// Version counters are used to detect modifications to saved variables which  
// would result in incorrect gradient calculations. Version counters may be  
// shared between Variables:  
//  
// 1. A view shares the version counter of the base Variable,  
// 2. `x.detach()` shares the version counter of `x`,  
// 3. Unpacked saved variables share the version counter of the source.  
//  
// Version counters are not shared in these scenarios:  
//  
// 1. When we replace a `Variable`'s underlying `Tensor` by calling  
// `set_data(...)`,
// 2. `x.data` does not share the version counter of `x`. (See discussion at  
// https://github.com/pytorch/pytorch/issues/5396)  
//  
// Question: Why do we put the version counter in TensorImpl instead of  
// AutogradMeta?  
//  
// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta  
// when its `requires_grad_` is false, but when we use this tensor in the
// forward pass of a function that requires saving this tensor for backward, we  
// need to keep track of this tensor's version to make sure it's always valid in  
// the autograd graph.  
//  
// To achieve this goal, we put the version counter in TensorImpl instead of  
// AutogradMeta, and have it always be available. This allows us to have the  
// optimization of not carrying AutogradMeta when a tensor doesn't require  
// gradient.  
//  
// A hypothetical alternative way to achieve this goal is to initialize  
// AutogradMeta and create the version counter for the non-requires-grad tensor  
// only when it's saved for backward. However, since saving a tensor for  
// backward happens in the forward pass, and our invariant is that forward pass  
// needs to be thread-safe, lazy-initializing AutogradMeta when saving a tensor  
// can introduce race conditions when we are running the forward pass in  
// multi-thread scenarios, thus making the forward pass not thread-safe anymore,  
// which breaks the invariant.  
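A quick sketch of this check in action, assuming the `_version()` accessor (the C++ counterpart of Python's `Tensor._version`): mutating a tensor that an earlier op saved for backward bumps its version counter, and `backward()` then refuses to use the stale saved copy.

```
#include <torch/torch.h>
#include <iostream>

int main() {
  auto base = torch::ones({2}, torch::requires_grad());
  auto x = base.clone();              // non-leaf, so in-place ops are allowed
  auto y = x * x;                     // mul saves x for backward
  std::cout << x._version() << "\n";  // 0
  x.add_(1);                          // in-place op bumps x's version counter
  std::cout << x._version() << "\n";  // 1
  try {
    y.sum().backward();               // saved x no longer matches its version
  } catch (const std::exception& e) {
    std::cout << "backward failed: " << e.what() << "\n";
  }
}
```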

  
// c10/core/TensorImpl.h
struct C10_API VariableVersion {
 private:
  struct VersionCounter : intrusive_ptr_target {
    VersionCounter(uint32_t version) : version_(version) {}
    std::atomic<uint32_t> version_;
  };
  c10::intrusive_ptr<VersionCounter> version_counter_;

 public:
  // Note [Disabled VariableVersion]
  // The VariableVersion struct has an intrusive_ptr pointing to a
  // VersionCounter struct with an atomic variable. Thus
  // `VariableVersion(/*version=*/0)` is not as cheap as we expected. In some
  // cases constructing a VariableVersion with version 0 is not necessary, so
  // we add a cheap constructor which doesn't allocate the intrusive_ptr.
  // Example use cases are:
  //  - Inference tensors don't track version counters, so they'll just always
  //    have a disabled VariableVersion.
  //  - In the SavedVariable class we override version_counter_ inside its
  //    constructor so that we can use the cheap constructor there.
  enum Disabled { DISABLED };
  // It's okay to return true even for an inference tensor which
  // doesn't have a version counter enabled.
  // We want to be permissive here since in many cases (e.g. make_variable)
  // we can std::move a TensorImpl if there are no other uses, which saves us
  // an additional TensorImpl allocation.
  bool unique() const {
    return version_counter_ ? 1 == version_counter_.use_count() : true;
  }
  // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
  // leaves it in a persistently undefined state. See
  // https://cplusplus.github.io/LWG/issue2334.
  VariableVersion(uint32_t version)
      : version_counter_(c10::make_intrusive<VersionCounter>(version)) {}
  VariableVersion(Disabled = DISABLED) {}

  bool enabled() const {
    return version_counter_;
  }

  // Note [Inplace update inference tensor]
  // 1. Inplace update to an inference tensor is forbidden in normal mode.
  //   For example:
  //     inference_tensor.copy_(normal_tensor_requires_grad)
  //   This inplace makes inference_tensor have requires_grad=True and
  //   have a grad_fn. This is bad because views of `inference_tensor`
  //   created in InferenceMode won't be able to know the grad_fn since
  //   their ViewMeta were not recorded. To match the NoGradMode behavior
  //   that "inplace update to a view created in NoGradMode raises an error",
  //   we just ban inplace update to inference tensors since we can't tell
  //   if an inference tensor is a view created in InferenceMode.
  //
  //   Note that views of a normal tensor created in InferenceMode have proper
  //   ViewMeta so that they're aware of the grad_fn correctly.
  //
  // 2. Inplace update to an inference tensor in InferenceMode doesn't bump
  //    the version counter.
  //    * It either doesn't call bump() by skipping the ADInplaceOrView kernel,
  //      - e.g. inference_tensor.add_(1)
  //    * or bump() is a no-op for inference tensors.
  //      - e.g. inference_tensor.add_(normal_tensor)
  void bump() {
    // TODO: Replace the link to the documentation once it's available.
    TORCH_CHECK(
        version_counter_ || InferenceMode::is_enabled(),
        "Inplace update to inference tensor outside InferenceMode is not allowed. "
        "You can make a clone to get a normal tensor before doing inplace update. "
        "See https://github.com/pytorch/rfcs/pull/17 for more details.");
    if (version_counter_) {
      ++version_counter_->version_;
    }
  }

  // Inference tensors don't have a version counter, so it shouldn't be
  // accessed.
  uint32_t current_version() const {
    TORCH_CHECK(
        version_counter_, "Inference tensors do not track version counter.");
    return version_counter_->version_;
  }
};

VariableInfo

  
struct TORCH_API VariableInfo {
  explicit VariableInfo();
  explicit VariableInfo(const Variable& var);

  Variable zeros(at::OptionalDeviceGuard& device_guard) const;

  at::Layout layout = at::Layout::Strided;
  at::Device device = at::kCPU;
  at::ScalarType scalar_type = at::kFloat;
  std::vector<int64_t> size;
  bool requires_grad;
  bool is_empty;
};

CppNode

  
// CppNode<T> is the Node in the autograd graph that represents the user-defined
// backward function for Function<T>. Calls to CppNode::apply are forwarded to
// T::backward().
template <class T>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct CppNode : public Node {
  variable_list apply(variable_list&& inputs) override;
  AutogradContext ctx_;
  std::vector<bool> is_variable_input_;
  std::vector<VariableInfo> input_info_;
  std::vector<VariableInfo> output_info_;

  void release_variables() override;

  void set_ctx_grad_fn(const std::shared_ptr<Node> &node);
  void save_variables_to_ctx();
};

apply

  
// The logic here is the same as PyNode::apply, so changes to it should be done
// in both places
template<class T>
variable_list CppNode<T>::apply(variable_list&& inputs) {
  at::OptionalDeviceGuard _device_guard;

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int num_inputs = inputs.size();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list backward_inputs;
  backward_inputs.reserve(num_inputs);
  for (const auto i : c10::irange(num_inputs)) {
    if (inputs[i].defined() || !ctx_.materialize_grads_) {
      backward_inputs.emplace_back(inputs[i]);
    } else {
      backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
    }
  }

  // Acquire the lock here to protect thread safety on the custom C++ Autograd
  // Node. This is needed for the custom Autograd Node since we don't know if
  // the user-defined Node will write to the shared data during backward.
  // see Note [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);

  auto outputs = T::backward(&ctx_, backward_inputs);

  const auto num_forward_inputs = static_cast<int64_t>(is_variable_input_.size());
  auto num_outputs = static_cast<int64_t>(outputs.size());
  // Returning too many results is ok, but only as long as they're all undefined.
  // Truncate the result vector in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_undef = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_undef &= (!outputs[i].defined());
    }
    if (all_undef) {
      outputs.resize(num_forward_inputs);
      num_outputs = num_forward_inputs;
    }
  }

  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += c10::to_string(num_forward_inputs) + ", got ";
    msg += c10::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  variable_list results;
  results.reserve(num_outputs);
  for (const auto i : c10::irange(num_outputs)) {
    if (!is_variable_input_[i]) {
      if (outputs[i].defined()) {
        std::string msg("function ");
        msg += name() + " returned a gradient different than None at position ";
        msg += c10::to_string(i + 1) + ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    results.emplace_back(outputs[i]);
  }
  return results;
}
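The gradient-count check above is easy to trip. A hypothetical function that does so (`BadAdd` is illustrative, not from the PyTorch sources): two tensor inputs, but `backward` returns a single gradient, so `CppNode::apply` throws "returned an incorrect number of gradients (expected 2, got 1)" when the backward pass reaches it:

```
#include <torch/torch.h>

using namespace torch::autograd;

struct BadAdd : public Function<BadAdd> {
  static Variable forward(AutogradContext* ctx, Variable a, Variable b) {
    return a + b;
  }
  static variable_list backward(AutogradContext* ctx, variable_list grad_out) {
    return {grad_out[0]};  // wrong: must be one gradient per forward input
  }
};

int main() {
  auto a = torch::ones({2}, torch::requires_grad());
  auto b = torch::ones({2}, torch::requires_grad());
  auto c = BadAdd::apply(a, b);
  c.sum().backward();  // throws std::runtime_error from the check above
}
```

The fix is to return `{grad_out[0], grad_out[0]}`, one entry for each of the two forward inputs.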

release_variables

  
template<class T>
void CppNode<T>::release_variables() {
  // lock to ensure thread safety, see [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);
  ctx_.saved_variables_.clear();
  ctx_.has_freed_buffers_ = true;
}

save_variables_to_ctx

  
template<class T>  
void CppNode<T>::save_variables_to_ctx() {  
  ctx_.save_variables();  
}  

set_ctx_grad_fn

  
template<class T>  
void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node> &node) {
  ctx_.grad_fn_ = node;  
}  
