Torch and OneFlow LogSumExp

unit test

  
# caffe2/python/operator_test/segment_ops_test.py
def logsumexp(x):
    """Reference log(sum(exp(x))) reduced over axis 0, computed stably.

    The per-column maximum is subtracted before exponentiating and added back
    after the log, so np.exp cannot overflow for large inputs (a naive
    log(sum(exp(x))) returns inf once x exceeds ~709 in float64). Non-finite
    maxima are replaced by 0 so all-(+/-)inf columns do not become
    inf - inf = nan.
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.max cannot reduce an empty array; the naive formula is safe here.
        return np.log(np.sum(np.exp(x), axis=0))
    shift = np.max(x, axis=0)
    shift = np.where(np.isfinite(shift), shift, 0)
    return shift + np.log(np.sum(np.exp(x - shift), axis=0))
  
  
def logsumexp_grad(grad_out, outputs, inputs):
    """Reference gradient of ``logsumexp`` (reduction over axis 0).

    d/dx log(sum(exp(x))) = exp(x) / sum(exp(x)) along axis 0, so the result is
    ``grad_out`` broadcast over axis 0 times that ratio. Both exp terms are
    shifted by the per-column maximum; the shift cancels in the ratio but keeps
    np.exp from overflowing (which would otherwise yield inf / inf = nan).

    Args:
        grad_out: upstream gradient; presumably shaped like the logsumexp
            output (x.shape[1:]) so it broadcasts over axis 0.
        outputs: forward outputs (unused).
        inputs: forward inputs; ``inputs[0]`` is the array reduced over axis 0.
    """
    x = np.asarray(inputs[0])
    if x.size:
        shift = np.max(x, axis=0)
        # Infinite maxima would produce inf - inf = nan; fall back to no shift.
        shift = np.where(np.isfinite(shift), shift, 0)
    else:
        shift = 0  # np.max cannot reduce an empty array; nothing to shift.
    exps = np.exp(x - shift)
    return exps * (grad_out / np.sum(exps, axis=0))

decompositions

  
# torch/_decomp/decompositions.py

@register_decomposition(aten.logsumexp.default)
@pw_cast_for_int_to_real
def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
    """Decompose aten.logsumexp into amax / exp / sum / log.

    Uses the max-shift trick logsumexp(x) = m + log(sum(exp(x - m))) with
    m = amax(x, dim). Infinite maxima are replaced by 0 *before* m is used as
    the shift, so reductions whose max is +/-inf compute exp(x - 0) rather than
    exp(inf - inf) = nan, and the result is +/-inf as expected.
    """
    if self.numel() == 0:
        # amax cannot reduce an empty tensor; the naive formula is safe here.
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(self, dim, keepdim=True)
    # Mask the keepdim maxes (used for the shift), not only the squeezed copy.
    maxes = torch.masked_fill(maxes, maxes.abs() == float('inf'), 0)
    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)

native_functions.yaml

  
# aten/src/ATen/native/native_functions.yaml

# Positional-dim overloads (functional and out=).
- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
  device_check: NoCheck   # TensorIterator
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: logsumexp

- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator
  dispatch:
    CompositeExplicitAutograd: logsumexp_out

# Named-dim overloads (dims given as Dimnames); no explicit dispatch entry.
- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
  device_check: NoCheck   # TensorIterator
  variants: function, method

- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck   # TensorIterator

derivatives.yaml

  
# tools/autograd/derivatives.yaml
# `self:` gives the backward (reverse-mode) formula, `result:` the forward-mode (JVP) formula.
- name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
  self: logsumexp_backward(grad, self, result, dim, keepdim)
  result: logsumexp_jvp(self_p, self_t, dim, keepdim)

DimnameList

  
// aten/src/ATen/core/Dimname.h  
  
  
namespace at {  
  
enum class NameType: uint8\_t { BASIC, WILDCARD };  
  
struct TORCH\_API Dimname {  
  static Dimname fromSymbol(Symbol name);  
  static Dimname wildcard();  
  static bool isValidName(const std::string& name);  
  
  NameType type() const { return type_; }  
  Symbol symbol() const { return name_; }  
  
  bool isBasic() const { return type_ == NameType::BASIC; }  
  bool isWildcard() const { return type_ == NameType::WILDCARD; }  
  
  bool matches(Dimname other) const;  
  c10::optional<Dimname> unify(Dimname other) const;  
  
 private:  
  Dimname(Symbol name)  
    : name_(name), type_(NameType::BASIC) {}  
  Dimname(Symbol name, NameType type)  
    : name_(name), type_(type) {}  
  
  Symbol name_;  
  NameType type_;  
};  
  
using DimnameList = c10::ArrayRef<Dimname>;  
  
TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);  
  
inline bool operator==(const Dimname& lhs, const Dimname& rhs) {  
  return lhs.symbol() == rhs.symbol();  
}  
  
inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {  
  return !(lhs == rhs);  
}  
  
} // namespace at  

Symbol

  
// aten/src/ATen/core/symbol.h  
  
namespace c10 {  
// 'prim' symbols are synthetic operators that occur only in the IR  
// and don't have corresponding implementations in ATen.  
  
// 'onnx' symbols correspond to ONNX operators.  Their semantics  
// are defined in https://github.com/onnx/onnx/blob/master/docs/Operators.md  
// The particular version we are targeting is specified by '\_onnx\_opset\_version'  
// in torch.onnx.symbolic\_helper  
//  
// In general, most ONNX operators won't get an entry here, because they  
// are handled from the Python end.  However, you may occasionally need  
// to intern an ONNX symbol here so that you can conveniently write an  
// optimization on ONNX operations.  
  
// 'attr' symbols are attribute keys.  They are shared between both ONNX and ATen  
// operators (you disambiguate their meaning by looking at the operator itself).  
// In general, you only need to define attribute keys that are used by  
// onnx or prim; ATen attributes are automatically generated in FORALL\_ATTR\_BASE\_SYMBOLS.  
  
// Note [Symbol allocation]  
// ~~~~~~~~~~~~~~~~~~~~~~~~  
//  
//  1. Symbol namespace is split up into namespaces.  
//  
//  2. The intended access pattern for built-in symbols is onnx::MatMul  
//  in the c10 namespace (this is a Symbol).  
//  
  
// Built-in constant definition strategy:  
// - Enum is the most convenient way to generate a contiguous sequence  
//   of numbers for an identifier.  
// - However, an enum gives you a fresh type.  We want onnx::MatMul to  
//   be type Symbol, not some random enum type!  
// - Therefore, after using enums to generate the sequence of integers,  
//   we then declare constexpr Symbols to get everything the actual Symbol  
//   type we want.  Symbols must be constexpr to be valid to be "case"ed on.  
  
using unique\_t = uint32\_t;  
  
const std::string& domain\_prefix();  
  
// A Symbol is like an interned string, but with a little extra  
// structure; it is namespaced via SymbolNamespace and the resulting  
// intern pointers support efficient namespace testing.  
struct TORCH\_API Symbol {  
  explicit constexpr Symbol() : value(0) {};  
  explicit constexpr Symbol(unique\_t uniq)  
  : value(uniq) {}  
  
  // So we can switch on this  
  constexpr operator unique\_t() const {  
    return value;  
  }  
    
  ...  
    
private:  
  
  explicit Symbol(Symbol ns, const std::string & s);  
  unique\_t value;  
};  
  
static inline bool operator==(Symbol lhs, Symbol rhs) {  
  return static\_cast<unique\_t>(lhs) == static\_cast<unique\_t>(rhs);  
}  
  
} // namespace c10  
  
// make symbol behave like an integer in hash tables  
namespace std {  
template <>  
struct hash<c10::Symbol> {  
  size\_t operator()(c10::Symbol s) const {  
    return std::hash<uint32\_t>()(static\_cast<uint32\_t>(s));  
  }  
};  
}  

ArrayRef

  
// c10/util/ArrayRef.h  
  
  
namespace c10 {  
  
using IntArrayRef = ArrayRef<int64\_t>;  
  
/// ArrayRef - Represent a constant reference to an array (0 or more elements  
/// consecutively in memory), i.e. a start pointer and a length.  It allows  
/// various APIs to take consecutive elements easily and conveniently.  
///  
/// This class does not own the underlying data, it is expected to be used in  
/// situations where the data resides in some other buffer, whose lifetime  
/// extends past that of the ArrayRef. For this reason, it is not in general  
/// safe to store an ArrayRef.  
///  
/// This is intended to be trivially copyable, so it should be passed by  
/// value.  
template <typename T>  
class ArrayRef final {  
 public:  
  using iterator = const T*;  
  using const_iterator = const T*;  
  using size_type = size\_t;  
  using value_type = T;  
  
  using reverse_iterator = std::reverse_iterator<iterator>;  
  
 private:  
  /// The start of the array, in an external buffer.  
  const T* Data;  
  
  /// The number of elements.  
  size_type Length;  
  
  void debugCheckNullptrInvariant() {  
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(  
        Data != nullptr || Length == 0,  
        "created ArrayRef with nullptr and non-zero length! c10::optional relies on this being illegal");  
  }  
  
 public:  
  /// @name Constructors  
  /// @{  
  
  /// Construct an empty ArrayRef.  
  /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {}  
  
  /// Construct an ArrayRef from a single element.  
  // TODO Make this explicit  
  constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}  
  
  /// Construct an ArrayRef from a pointer and length.  
  C10\_HOST\_CONSTEXPR\_EXCEPT\_WIN\_CUDA ArrayRef(const T* data, size\_t length)  
      : Data(data), Length(length) {  
    debugCheckNullptrInvariant();  
  }  
  
  /// Construct an ArrayRef from a range.  
  C10\_HOST\_CONSTEXPR\_EXCEPT\_WIN\_CUDA ArrayRef(const T* begin, const T* end)  
      : Data(begin), Length(end - begin) {  
    debugCheckNullptrInvariant();  
  }  
  
  /// Construct an ArrayRef from a SmallVector. This is templated in order to  
  /// avoid instantiating SmallVectorTemplateCommon<T> whenever we  
  /// copy-construct an ArrayRef.  
  template <typename U>  
  /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)  
      : Data(Vec.data()), Length(Vec.size()) {  
    debugCheckNullptrInvariant();  
  }  
  
  template <  
      typename Container,  
      typename = std::enable\_if\_t<std::is_same<  
          std::remove\_const\_t<decltype(std::declval<Container>().data())>,  
          T*>::value>>  
  /* implicit */ ArrayRef(const Container& container)  
      : Data(container.data()), Length(container.size()) {  
    debugCheckNullptrInvariant();  
  }  
  
  /// Construct an ArrayRef from a std::vector.  
  // The enable\_if stuff here makes sure that this isn't used for  
  // std::vector<bool>, because ArrayRef can't work on a std::vector<bool>  
  // bitfield.  
  template <typename A>  
  /* implicit */ ArrayRef(const std::vector<T, A>& Vec)  
      : Data(Vec.data()), Length(Vec.size()) {  
    static\_assert(  
        !std::is_same<T, bool>::value,  
        "ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");  
  }  
  
  /// Construct an ArrayRef from a std::array  
  template <size\_t N>  
  /* implicit */ constexpr ArrayRef(const std::array<T, N>& Arr)  
      : Data(Arr.data()), Length(N) {}  
  
  /// Construct an ArrayRef from a C array.  
  template <size\_t N>  
  /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}  
  
  /// Construct an ArrayRef from a std::initializer\_list.  
  /* implicit */ constexpr ArrayRef(const std::initializer\_list<T>& Vec)  
      : Data(  
            std::begin(Vec) == std::end(Vec) ? static\_cast<T*>(nullptr)  
                                             : std::begin(Vec)),  
        Length(Vec.size()) {}  
          
    ...        
};  
} // namespace c10  

maybe_wrap_dim

  
// c10/core/WrapDimMinimal.cpp  
  
namespace c10 {  
namespace detail {  
  
int64\_t maybe\_wrap\_dim\_slow(  
    int64\_t dim,  
    int64\_t dim\_post\_expr,  
    bool wrap\_scalar) {  
  if (dim_post_expr <= 0) {  
    TORCH_CHECK_INDEX(  
        wrap_scalar,  
        "dimension specified as ",  
        dim,  
        " but tensor has no dimensions");  
    return c10::maybe_wrap_dim(dim, /*dim\_post\_expr=*/1, /*wrap\_scalar=*/false);  
  }  
  
  int64\_t min = -dim_post_expr;  
  int64\_t max = dim_post_expr - 1;  
  TORCH_CHECK_INDEX(  
      min <= dim && dim <= max,  
      "Dimension out of range (expected to be in range of [",  
      min,  
      ", ",  
      max,  
      "], but got ",  
      dim,  
      ")");  
  
  TORCH_INTERNAL_ASSERT(  
      false, "should never reach here as dim should be out-of-bounds");  
}  
  
} // namespace detail  
} // namespace c10  

dim_list_to_bitset

  
// aten/src/ATen/WrapDimUtilsMulti.h  
  
namespace at {  
  
// This is in an extra file to work around strange interaction of  
// bitset on Windows with operator overloading  
  
constexpr size\_t dim_bitset_size = 64;  
  
static inline std::bitset<dim\_bitset\_size> dim\_list\_to\_bitset(  
    IntArrayRef dims,  
    int64\_t ndims) {  
  TORCH_CHECK(  
      ndims <= (int64\_t)dim_bitset_size,  
      "only tensors with up to ",  
      dim_bitset_size,  
      " dims are supported");  
        
  std::bitset<dim_bitset_size> seen;  
    
  for (const auto i : c10::irange(dims.size())) {  
    size\_t dim = maybe_wrap_dim(dims[i], ndims);  
    TORCH_CHECK(  
        !seen[dim], "dim ", dim, " appears multiple times in the list of dims");  
          
    seen[dim] = true;  
  }  
  return seen;  
}  
  
} // namespace at  

unsqueeze_multiple

  
// torch/csrc/autograd/FunctionsManual.cpp

// Re-inserts a size-1 axis at every position listed in `dim`. Positions
// (possibly negative) are interpreted relative to the output rank `n_dims`.
// Axes are inserted in increasing index order, so each index already refers
// to its final position in the result.
Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
    auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
    Tensor res = t;
    for (const auto i : c10::irange(n_dims)) {
      if (dims_to_unsqueeze[i]) {
        res = res.unsqueeze(i);
      }
    }
    return res;
}

squeeze_multiple

  
// aten/src/ATen/native/ReduceOps.cpp

// Squeezes every axis listed in `dims` (possibly negative, relative to self's
// rank). Axes are processed from the highest index down, so removing one does
// not shift the positions of those still to be squeezed.
static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
  const int64_t ndims = self.dim();
  auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
  Tensor result = self;
  for (int64_t i = ndims - 1; i >= 0; --i) {
    if (dims_to_squeeze[i]) {
      result = result.squeeze(i);
    }
  }
  return result;
}

logsumexp

  
// aten/src/ATen/native/ReduceOps.cpp

// Named-dim overload: converts `dims` to positional indices and forwards to
// the positional logsumexp.
Tensor logsumexp(const Tensor& self, DimnameList dims, bool keepdim) {
  return at::logsumexp(self, dimnames_to_positions(self, dims), keepdim);
}

// Named-dim out= overload: same conversion, writes into `result`.
Tensor& logsumexp_out(const Tensor& self, DimnameList dims, bool keepdim, Tensor& result) {
  return at::logsumexp_out(result, self, dimnames_to_positions(self, dims), keepdim);
}

// special_logsumexp, alias for logsumexp
Tensor special_logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
  return self.logsumexp(dims, keepdim);
}

// out= form of special_logsumexp; forwards to logsumexp_out.
Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
  return at::logsumexp_out(result, self, dims, keepdim);
}

  
// torch/csrc/api/include/torch/special.h

/// Computes log(sum(exp(input))) reduced over the dimensions `dims`; with
/// `keepdim` the reduced dimensions are kept with size 1.
/// See https://pytorch.org/docs/master/special.html#torch.special.logsumexp.
///
/// Example:
/// ```
/// auto t = torch::randn({3, 3});
/// torch::special::logsumexp(t, /*dims=*/1, /*keepdim=*/false);
/// ```
inline Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
  return torch::special_logsumexp(self, dims, keepdim);
}

/// out= variant of torch::special::logsumexp: writes into `result` and returns it.
inline Tensor& logsumexp_out(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
  return torch::special_logsumexp_out(result, self, dims, keepdim);
}
  

  
// aten/src/ATen/native/ReduceOps.cpp  
  
// True for the integer scalar types Byte, Char, Short, Int and Long.
// Bool counts as integral only when `includeBool` is set; every other
// type (floating, complex, ...) yields false.
static inline bool isIntegralType(ScalarType t, bool includeBool) {
  switch (t) {
    case ScalarType::Byte:
    case ScalarType::Char:
    case ScalarType::Short:
    case ScalarType::Int:
    case ScalarType::Long:
      return true;
    case ScalarType::Bool:
      return includeBool;
    default:
      return false;
  }
}
  
  
// Functional logsumexp over positional `dims`.
// Allocates an empty output and lets logsumexp_outf fill it. The output takes
// self's options, except that integral/bool inputs get the default floating
// dtype instead of their own.
Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
  const bool promote_to_float =
      at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
  // even for integral inputs, result is floating dtype
  const TensorOptions result_options = promote_to_float
      ? self.options().dtype(at::typeMetaToScalarType(c10::get_default_dtype()))
      : self.options();
  Tensor result = at::empty({0}, result_options);
  return at::logsumexp_outf(self, dims, keepdim, result);
}

logsumexp_outf

logsumexp_out

  
// out= overload of logsumexp over positional `dims`.
// `result` must already have a floating-point dtype (TORCH_CHECK below).
// Integral and bool inputs are converted to the default floating dtype first.
// The reduction runs with dimension names disabled (NoNamesGuard); names are
// then propagated to `result` via propagate_names_for_reduction.
Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
  TORCH_CHECK(at::isFloatingType(result.scalar_type()),
              "logsumexp(): Expected floating point type for result tensor, but got: ",
              result.scalar_type());
  {
    NoNamesGuard guard;
    if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
      // for integral inputs, promote input to default floating type.
      auto default_dtype = at::typeMetaToScalarType(c10::get_default_dtype());
      logsumexp_out_impl(result, self.to(default_dtype), dims, keepdim);
    } else {
      logsumexp_out_impl(result, self, dims, keepdim);
    }
  }
  namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
  return result;
}

logsumexp_out_impl

  
// aten/src/ATen/native/ReduceOps.cpp

// Computes logsumexp over `dims` into `result` with the max-shift trick:
//   logsumexp(x) = m + log(sum(exp(x - m))),  m = amax(x, dims)
// so exp() does not overflow for large inputs.
static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
  // can't take max of empty tensor
  if (self.numel() != 0) {
    auto maxes = at::amax(self, dims, true);
    // Replace +/-inf maxima by 0 before they are used as the shift, so
    // reductions whose max is infinite compute exp(x - 0) instead of
    // exp(inf - inf) = nan. Masking `maxes` itself (rather than its squeezed
    // counterpart) does not depend on squeeze_multiple returning a view that
    // aliases `maxes`.
    maxes.masked_fill_(maxes.abs() == INFINITY, 0);
    auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
    at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
    result.log_().add_(maxes_squeezed);
  } else {
    at::sum_out(result, at::exp(self), dims, keepdim);
    result.log_();
  }
  return result;
}

logsumexp_backward

  
// torch/csrc/autograd/FunctionsManual.cpp

// Reverse-mode gradient of logsumexp over `dim`.
// d logsumexp(x) / dx = exp(x - logsumexp(x)), so the gradient is
// grad * exp(self - result), where `self` is the forward input and `result`
// the forward output. If the forward reduced with keepdim=false (and self is
// not 0-dim), `grad` and `result` are first unsqueezed back to self's rank so
// they broadcast against `self`.
Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) {
  if (!keepdim && self.dim() != 0) {
    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
    result = unsqueeze_multiple(result, dim, self.sizes().size());
  }
  return grad * (self - result).exp();
}

logsumexp_jvp

  
// torch/csrc/autograd/FunctionsManual.cpp

// Forward-mode derivative of logsumexp over `dim`:
//   jvp = sum(exp(p) * t) / sum(exp(p)),  p = self_p (primal), t = self_t (tangent)
// Both sums are shifted by amax(self_p) for numerical stability; the shift
// cancels in the ratio.
Tensor logsumexp_jvp(const Tensor& self_p, const Tensor& self_t, IntArrayRef dim, bool keepdim) {
  // NB: for simplicity, we recompute some values that can be reused from forward
  auto self_p_exp = [&self_p, &dim]() {
    if (self_p.numel() > 0) {
      // Use the exp-normalize trick
      return (self_p - at::amax(self_p, dim, true)).exp();
    }
    // amax cannot reduce an empty tensor (the forward supports empty inputs);
    // with no elements there is nothing to overflow, so skip the shift.
    return self_p.exp();
  }();
  auto sumexp_p = self_p_exp.sum(dim, keepdim);

  // NB: it's OK for logsumexp_jvp to be reused for formulas like softmax/log_softmax
  //     that only have one differentiable input, because that means self_t are never zerotensors
  TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());
  if (areAnyTensorSubclassLike({self_p, self_t})) {
    // Out-of-place path for tensor subclasses.
    auto result = (self_p_exp * self_t).sum(dim, keepdim);
    result /= sumexp_p;
    return result;
  } else {
    // self_p_exp is a fresh temporary, so it can be scaled in place.
    self_p_exp *= self_t;
    auto sumexp_t = self_p_exp.sum(dim, keepdim);
    return sumexp_t /= sumexp_p;
  }
}

getOperatorAliasMap

  
// torch/csrc/jit/passes/normalize\_ops.cpp  
  
const std::unordered\_map<Symbol, Symbol>& getOperatorAliasMap() {  
  // map from op alias -> normalized op  
  static const std::unordered\_map<Symbol, Symbol> alias_map = {  
      {aten::absolute, aten::abs},  
      {aten::absolute_, aten::abs_},  
      {aten::clip, aten::clamp},  
      {aten::clip_, aten::clamp_},  
      {aten::det, aten::linalg_det},  
      {aten::matrix_power, aten::linalg_matrix_power},  
      {aten::matrix_exp, aten::linalg_matrix_exp},  
      {aten::ger, aten::outer},  
      {aten::arccos, aten::acos},  
      {aten::arccos_, aten::acos_},  
      {aten::arcsin, aten::asin},  
      {aten::arcsin_, aten::asin_},  
      {aten::arctan, aten::atan},  
      {aten::arctan_, aten::atan_},  
      {aten::arctan2, aten::atan2},  
      {aten::arctan2_, aten::atan2_},  
      {aten::arccosh, aten::acosh},  
      {aten::arccosh_, aten::acosh_},  
      {aten::arcsinh, aten::asinh},  
      {aten::arcsinh_, aten::asinh_},  
      {aten::arctanh, aten::atanh},  
      {aten::arctanh_, aten::atanh_},  
      {aten::fix, aten::trunc},  
      {aten::fix_, aten::trunc_},  
      {aten::negative, aten::neg},  
      {aten::negative_, aten::neg_},  
      {aten::subtract, aten::sub},  
      {aten::subtract_, aten::sub_},  
      {aten::greater_equal, aten::ge},  
      {aten::greater_equal_, aten::ge_},  
      {aten::greater, aten::gt},  
      {aten::greater_, aten::gt_},  
      {aten::less_equal, aten::le},  
      {aten::less_equal_, aten::le_},  
      {aten::less, aten::lt},  
      {aten::less_, aten::lt_},  
      {aten::not_equal, aten::ne},  
      {aten::not_equal_, aten::ne_},  
      {aten::divide, aten::div},  
      {aten::divide_, aten::div_},  
      {aten::multiply, aten::mul},  
      {aten::multiply_, aten::mul_},  
      {aten::linalg_matmul, aten::matmul},  
      {aten::true_divide, aten::div},  
      {aten::true_divide_, aten::div_},  
      {aten::concat, aten::cat},  
      {aten::row_stack, aten::vstack},  
      {aten::swapdims, aten::transpose},  
      {aten::swapdims_, aten::transpose_},  
      {aten::swapaxes, aten::transpose},  
      {aten::swapaxes_, aten::transpose_},  
      {aten::moveaxis, aten::movedim},  
      {aten::special_erf, aten::erf},  
      {aten::special_erfc, aten::erfc},  
      {aten::special_erfinv, aten::erfinv},  
      {aten::special_expit, aten::sigmoid},  
      {aten::special_exp2, aten::exp2},  
      {aten::special_expm1, aten::expm1},  
      {aten::special_logit, aten::logit},  
      {aten::special_logsumexp, aten::logsumexp},  
      {aten::special_round, aten::round},  
      {aten::special_log1p, aten::log1p},  
      {aten::special_sinc, aten::sinc},  
      {aten::special_digamma, aten::digamma},  
      {aten::special_psi, aten::digamma},  
      {aten::special_i0, aten::i0},  
      {aten::special_xlogy, aten::xlogy},  
      {aten::special_log_softmax, aten::log_softmax},  
      {aten::orgqr, aten::linalg_householder_product},  
      {aten::adjoint, aten::mH},  
      {aten::special_multigammaln, aten::mvlgamma},  
      {aten::special_polygamma, aten::polygamma},  
      {aten::special_softmax, aten::softmax},  
      {aten::special_gammainc, aten::igamma},  
      {aten::special_gammaincc, aten::igammac},  
      {aten::special_gammaln, aten::lgamma}};  
  return alias_map;  
}  

OneFlow

  
// oneflow/core/functional/impl/math_functor.cpp

// logsumexp(x) reduced over `axis`, using the max-shift trick
//   logsumexp(x) = m + log(sum(exp(x - m))),  m = amax(x, axis)
// so Exp does not overflow for large inputs.
class LogSumExpFunctor {
 public:
  LogSumExpFunctor() {}
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,
                           const bool& keepdims) const {
    if (x->ndim() == 0) {
      // can't take amax of 0-dim tensor; logsumexp of a single value is the value itself.
      // NOTE(review): this casts every 0-dim input to float (including double) — confirm intended.
      return To(x, JUST(DType::Get(DataType::kFloat)), false);
    } else if (x->nelement() == 0) {
      // can't take amax of empty tensor
      std::shared_ptr<one::Tensor> exp_out = JUST(Exp(x));
      return Log(JUST(ReduceSum(exp_out, axis, keepdims)));
    } else {
      std::shared_ptr<one::Tensor> maxes = JUST(Amax(x, axis, true));
      // Replace +/-inf maxima with 0 *before* they are used as the shift, so slices whose max is
      // infinite compute exp(x - 0) rather than exp(inf - inf) = nan. Masking `maxes` directly
      // (instead of only its squeezed counterpart) does not rely on Squeeze returning a view.
      JUST(MaskedFillInplace(maxes, JUST(ScalarLogicalEqual(JUST(Abs(maxes)), INFINITY)), 0));
      std::shared_ptr<one::Tensor> maxes_squeezed =
          keepdims ? maxes : JUST(SqueezeMultiple(maxes, axis));
      std::shared_ptr<one::Tensor> exp_out = JUST(Exp(JUST(Sub(x, maxes, 1, false))));
      return Add(JUST(Log(JUST(ReduceSum(exp_out, axis, keepdims)))), maxes_squeezed, 1, false);
    }
  }

 private:
  // Squeezes every axis listed in `axis` (interpreted against x's rank), from the highest index
  // down so that removing one axis does not shift the positions of the remaining ones.
  Maybe<Tensor> SqueezeMultiple(const std::shared_ptr<one::Tensor>& x,
                                const std::vector<int32_t>& axis) const {
    int ndims = x->ndim();
    const auto& dims_to_squeeze = JUST(dim_list_to_bitset(axis, ndims));
    std::shared_ptr<one::Tensor> result = x;
    for (int i = ndims - 1; i >= 0; --i) {
      if ((*dims_to_squeeze)[i]) {
        std::vector<int32_t> dims = {i};
        result = JUST(Squeeze(result, dims));
      }
    }
    return result;
  }
};

picture.image

0
0
0
0
评论
未登录
暂无评论