Unit test — Caffe2 NumPy reference implementation
# caffe2/python/operator_test/segment_ops_test.py
def logsumexp(x):
    """Reference logsumexp reduced over axis 0: log(sum(exp(x), axis=0)).

    NOTE(review): no max-subtraction stabilization here, so large inputs
    can overflow exp(); acceptable for a small-value test reference.
    """
    return np.log(np.sum(np.exp(x), axis=0))
def logsumexp_grad(grad_out, outputs, inputs):
    """Reference gradient of logsumexp w.r.t. inputs[0].

    d/dx_i log(sum_j exp(x_j)) = exp(x_i) / sum_j exp(x_j) (a softmax over
    axis 0), scaled by grad_out broadcast along axis 0.

    `outputs` is unused; it is kept for the gradient-checker calling
    convention (grad_out, outputs, inputs).
    """
    sum_exps = np.sum(np.exp(inputs[0]), axis=0)
    return np.repeat(
        np.expand_dims(grad_out / sum_exps, 0),
        inputs[0].shape[0],
        axis=0
    ) * np.exp(inputs[0])
Decomposition — PyTorch reference decomposition (registered via `register_decomposition`)
@register_decomposition(aten.logsumexp.default)
@pw_cast_for_int_to_real
def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
    """Numerically stable logsumexp: subtract the per-reduction max before
    exp(), then add it back after log() (the "max trick")."""
    if self.numel() == 0:
        # Empty tensor: log(sum(exp(x))) = log(0) = -inf; no stabilization needed.
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(self, dim, keepdim=True)
    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
    # +/-inf maxes would produce inf - inf = nan below; zero them out.
    # exp(x - 0) still saturates to the correct value for infinite inputs.
    maxes_squeezed = torch.masked_fill(maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)
native_functions.yaml
# ATen operator schemas for logsumexp (positional and named-dim variants).
- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
  device_check: NoCheck
  variants: function, method
  dispatch:
    CompositeExplicitAutograd: logsumexp

- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
  dispatch:
    CompositeExplicitAutograd: logsumexp_out

- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
  device_check: NoCheck
  variants: function, method

- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck
derivatives.yaml
# Autograd formulas: `self` is the backward (VJP) entry, `result` the forward-mode (JVP) entry.
- name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
  self: logsumexp_backward(grad, self, result, dim, keepdim)
  result: logsumexp_jvp(self_p, self_t, dim, keepdim)
DimnameList
namespace at {
enum class NameType: uint8\_t { BASIC, WILDCARD };
struct TORCH\_API Dimname {
static Dimname fromSymbol(Symbol name);
static Dimname wildcard();
static bool isValidName(const std::string& name);
NameType type() const { return type_; }
Symbol symbol() const { return name_; }
bool isBasic() const { return type_ == NameType::BASIC; }
bool isWildcard() const { return type_ == NameType::WILDCARD; }
bool matches(Dimname other) const;
c10::optional<Dimname> unify(Dimname other) const;
private:
Dimname(Symbol name)
: name_(name), type_(NameType::BASIC) {}
Dimname(Symbol name, NameType type)
: name_(name), type_(type) {}
Symbol name_;
NameType type_;
};
using DimnameList = c10::ArrayRef<Dimname>;
TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname);
inline bool operator==(const Dimname& lhs, const Dimname& rhs) {
return lhs.symbol() == rhs.symbol();
}
inline bool operator!=(const Dimname& lhs, const Dimname& rhs) {
return !(lhs == rhs);
}
}
Symbol
namespace c10 {
using unique\_t = uint32\_t;
const std::string& domain\_prefix();
struct TORCH\_API Symbol {
explicit constexpr Symbol() : value(0) {};
explicit constexpr Symbol(unique\_t uniq)
: value(uniq) {}
constexpr operator unique\_t() const {
return value;
}
...
private:
explicit Symbol(Symbol ns, const std::string & s);
unique\_t value;
};
static inline bool operator==(Symbol lhs, Symbol rhs) {
return static\_cast<unique\_t>(lhs) == static\_cast<unique\_t>(rhs);
}
}
namespace std {
template <>
struct hash<c10::Symbol> {
size\_t operator()(c10::Symbol s) const {
return std::hash<uint32\_t>()(static\_cast<uint32\_t>(s));
}
};
}
ArrayRef
namespace c10 {
using IntArrayRef = ArrayRef<int64\_t>;
template <typename T>
class ArrayRef final {
public:
using iterator = const T*;
using const_iterator = const T*;
using size_type = size\_t;
using value_type = T;
using reverse_iterator = std::reverse_iterator<iterator>;
private:
const T* Data;
size_type Length;
void debugCheckNullptrInvariant() {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
Data != nullptr || Length == 0,
"created ArrayRef with nullptr and non-zero length! c10::optional relies on this being illegal");
}
public:
constexpr ArrayRef() : Data(nullptr), Length(0) {}
constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {}
C10\_HOST\_CONSTEXPR\_EXCEPT\_WIN\_CUDA ArrayRef(const T* data, size\_t length)
: Data(data), Length(length) {
debugCheckNullptrInvariant();
}
C10\_HOST\_CONSTEXPR\_EXCEPT\_WIN\_CUDA ArrayRef(const T* begin, const T* end)
: Data(begin), Length(end - begin) {
debugCheckNullptrInvariant();
}
template <typename U>
ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
debugCheckNullptrInvariant();
}
template <
typename Container,
typename = std::enable\_if\_t<std::is_same<
std::remove\_const\_t<decltype(std::declval<Container>().data())>,
T*>::value>>
ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();
}
template <typename A>
ArrayRef(const std::vector<T, A>& Vec)
: Data(Vec.data()), Length(Vec.size()) {
static\_assert(
!std::is_same<T, bool>::value,
"ArrayRef<bool> cannot be constructed from a std::vector<bool> bitfield.");
}
template <size\_t N>
constexpr ArrayRef(const std::array<T, N>& Arr)
: Data(Arr.data()), Length(N) {}
template <size\_t N>
constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {}
constexpr ArrayRef(const std::initializer\_list<T>& Vec)
: Data(
std::begin(Vec) == std::end(Vec) ? static\_cast<T*>(nullptr)
: std::begin(Vec)),
Length(Vec.size()) {}
...
};
}
maybe_wrap_dim
namespace c10 {
namespace detail {
int64\_t maybe\_wrap\_dim\_slow(
int64\_t dim,
int64\_t dim\_post\_expr,
bool wrap\_scalar) {
if (dim_post_expr <= 0) {
TORCH_CHECK_INDEX(
wrap_scalar,
"dimension specified as ",
dim,
" but tensor has no dimensions");
return c10::maybe_wrap_dim(dim, 1, false);
}
int64\_t min = -dim_post_expr;
int64\_t max = dim_post_expr - 1;
TORCH_CHECK_INDEX(
min <= dim && dim <= max,
"Dimension out of range (expected to be in range of [",
min,
", ",
max,
"], but got ",
dim,
")");
TORCH_INTERNAL_ASSERT(
false, "should never reach here as dim should be out-of-bounds");
}
}
}
dim_list_to_bitset
namespace at {
constexpr size\_t dim_bitset_size = 64;
static inline std::bitset<dim\_bitset\_size> dim\_list\_to\_bitset(
IntArrayRef dims,
int64\_t ndims) {
TORCH_CHECK(
ndims <= (int64\_t)dim_bitset_size,
"only tensors with up to ",
dim_bitset_size,
" dims are supported");
std::bitset<dim_bitset_size> seen;
for (const auto i : c10::irange(dims.size())) {
size\_t dim = maybe_wrap_dim(dims[i], ndims);
TORCH_CHECK(
!seen[dim], "dim ", dim, " appears multiple times in the list of dims");
seen[dim] = true;
}
return seen;
}
}
unsqueeze_multiple
// torch/csrc/autograd/FunctionsManual.cpp
// Re-inserts size-1 dims at every position flagged in `dim`, producing a
// tensor of rank n_dims that broadcasts against the pre-reduction shape.
Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) {
  auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
  Tensor res = t;
  // Ascending order: each unsqueeze index is relative to the tensor that
  // already contains the previously inserted dims.
  for (const auto i : c10::irange(n_dims)) {
    if (dims_to_unsqueeze[i]) {
      res = res.unsqueeze(i);
    }
  }
  return res;
}
squeeze_multiple
// torch/csrc/autograd/FunctionsManual.cpp
// Removes every dim flagged in `dims` from `self`.
// NOTE(review): loop header was truncated in the pasted snippet; it is
// reconstructed to match the OneFlow SqueezeMultiple twin below.
static Tensor squeeze_multiple(const Tensor& self, IntArrayRef dims) {
  int ndims = self.sizes().size();
  auto dims_to_squeeze = at::dim_list_to_bitset(dims, ndims);
  Tensor result = self;
  // Descending order so the not-yet-squeezed indices remain valid.
  for (int i = ndims - 1; i >= 0; --i) {
    if (dims_to_squeeze[i]) {
      result = result.squeeze(i);
    }
  }
  return result;
}
logsumexp
// Named-dimension overload: resolve dim names to positions, then defer to
// the positional implementation.
Tensor logsumexp(const Tensor& self, DimnameList dims, bool keepdim) {
  const auto positions = dimnames_to_positions(self, dims);
  return at::logsumexp(self, positions, keepdim);
}
// Named-dimension out= overload: resolve names, then defer to the
// positional out= variant.
Tensor& logsumexp_out(const Tensor& self, DimnameList dims, bool keepdim, Tensor& result) {
  return at::logsumexp_out(result, self, dimnames_to_positions(self, dims), keepdim);
}
// torch.special.logsumexp is an alias: forwards to the method variant.
Tensor special_logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
  return self.logsumexp(dims, keepdim);
}
// out= variant of the special-namespace alias; forwards to at::logsumexp_out.
Tensor& special_logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
  return at::logsumexp_out(result, self, dims, keepdim);
}
// C++ API alias: torch::logsumexp forwards to torch::special_logsumexp.
inline Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
  auto out = torch::special_logsumexp(self, dims, keepdim);
  return out;
}
// C++ API alias: torch::logsumexp_out forwards to torch::special_logsumexp_out.
inline Tensor& logsumexp_out(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
  return torch::special_logsumexp_out(result, self, dims, keepdim);
}
// True for the integral scalar types; Bool counts only when includeBool is set.
static inline bool isIntegralType(ScalarType t, bool includeBool) {
  switch (t) {
    case ScalarType::Byte:
    case ScalarType::Char:
    case ScalarType::Int:
    case ScalarType::Long:
    case ScalarType::Short:
      return true;
    case ScalarType::Bool:
      return includeBool;
    default:
      return false;
  }
}
// Functional variant: allocates the result (promoting integral/bool inputs
// to the default floating dtype) and dispatches to the out= kernel.
Tensor logsumexp(const Tensor& self, IntArrayRef dims, bool keepdim) {
  const bool needs_promotion = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
  const TensorOptions result_options = needs_promotion
      ? self.options().dtype(at::typeMetaToScalarType(c10::get_default_dtype()))
      : self.options();
  auto result = at::empty({0}, result_options);
  return at::logsumexp_outf(self, dims, keepdim, result);
}
logsumexp_outf
logsumexp_out
// out= kernel entry: validates the result dtype, promotes integral/bool
// inputs to the default floating dtype, and propagates named-dimension
// metadata for the reduction.
Tensor& logsumexp_out(const Tensor& self, IntArrayRef dims, bool keepdim, Tensor& result) {
  TORCH_CHECK(at::isFloatingType(result.scalar_type()),
              "logsumexp(): Expected floating point type for result tensor, but got: ",
              result.scalar_type());
  {
    // Suppress dim names during the computation; they are re-derived below.
    NoNamesGuard guard;
    if (at::isIntegralType(self.scalar_type(), true)) {
      auto default_dtype = at::typeMetaToScalarType(c10::get_default_dtype());
      logsumexp_out_impl(result, self.to(default_dtype), dims, keepdim);
    } else {
      logsumexp_out_impl(result, self, dims, keepdim);
    }
  }
  namedinference::propagate_names_for_reduction(result, self, dims, keepdim);
  return result;
}
logsumexp_out_impl
// Core numerically-stable logsumexp: subtract the per-reduction max before
// exponentiating, then add it back after the log (the "max trick").
static Tensor& logsumexp_out_impl(Tensor& result, const Tensor& self, IntArrayRef dims, bool keepdim) {
  if (self.numel() != 0) {
    auto maxes = at::amax(self, dims, true);
    auto maxes_squeezed = (keepdim ? maxes : squeeze_multiple(maxes, dims));
    // +/-inf maxes would give inf - inf = nan in the subtraction; zero them
    // out. exp(x - 0) still saturates correctly for infinite inputs.
    maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
    at::sum_out(result, (self - maxes).exp_(), dims, keepdim);
    result.log_().add_(maxes_squeezed);
  } else {
    // Empty tensor: log(sum(exp(x))) = log(0) = -inf; no stabilization needed.
    at::sum_out(result, at::exp(self), dims, keepdim);
    result.log_();
  }
  return result;
}
logsumexp_backward
// VJP: d/dx logsumexp(x) = exp(x - logsumexp(x)) — a softmax over the
// reduced dims — scaled by the incoming gradient.
Tensor logsumexp_backward(Tensor grad, const Tensor& self, Tensor result, IntArrayRef dim, bool keepdim) {
  if (!keepdim && self.dim() != 0) {
    // Re-insert the reduced dims so grad/result broadcast against self.
    grad = unsqueeze_multiple(grad, dim, self.sizes().size());
    result = unsqueeze_multiple(result, dim, self.sizes().size());
  }
  return grad * (self - result).exp();
}
logsumexp_jvp
// JVP: d logsumexp(x)[t] = sum(softmax(x) * t) over the reduced dims,
// computed as sum(exp(x - max) * t) / sum(exp(x - max)) for stability.
Tensor logsumexp_jvp(const Tensor& self_p, const Tensor& self_t, IntArrayRef dim, bool keepdim) {
  auto self_p_exp = (self_p - at::amax(self_p, dim, true)).exp();
  auto sumexp_p = self_p_exp.sum(dim, keepdim);
  TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());
  if (areAnyTensorSubclassLike({self_p, self_t})) {
    // Subclass tensors may not support in-place ops; use out-of-place math.
    auto result = (self_p_exp * self_t).sum(dim, keepdim);
    result /= sumexp_p;
    return result;
  } else {
    self_p_exp *= self_t;  // reuse the buffer to avoid an extra allocation
    auto sumexp_t = self_p_exp.sum(dim, keepdim);
    return sumexp_t /= sumexp_p;
  }
}
getOperatorAliasMap
const std::unordered\_map<Symbol, Symbol>& getOperatorAliasMap() {
static const std::unordered\_map<Symbol, Symbol> alias_map = {
{aten::absolute, aten::abs},
{aten::absolute_, aten::abs_},
{aten::clip, aten::clamp},
{aten::clip_, aten::clamp_},
{aten::det, aten::linalg_det},
{aten::matrix_power, aten::linalg_matrix_power},
{aten::matrix_exp, aten::linalg_matrix_exp},
{aten::ger, aten::outer},
{aten::arccos, aten::acos},
{aten::arccos_, aten::acos_},
{aten::arcsin, aten::asin},
{aten::arcsin_, aten::asin_},
{aten::arctan, aten::atan},
{aten::arctan_, aten::atan_},
{aten::arctan2, aten::atan2},
{aten::arctan2_, aten::atan2_},
{aten::arccosh, aten::acosh},
{aten::arccosh_, aten::acosh_},
{aten::arcsinh, aten::asinh},
{aten::arcsinh_, aten::asinh_},
{aten::arctanh, aten::atanh},
{aten::arctanh_, aten::atanh_},
{aten::fix, aten::trunc},
{aten::fix_, aten::trunc_},
{aten::negative, aten::neg},
{aten::negative_, aten::neg_},
{aten::subtract, aten::sub},
{aten::subtract_, aten::sub_},
{aten::greater_equal, aten::ge},
{aten::greater_equal_, aten::ge_},
{aten::greater, aten::gt},
{aten::greater_, aten::gt_},
{aten::less_equal, aten::le},
{aten::less_equal_, aten::le_},
{aten::less, aten::lt},
{aten::less_, aten::lt_},
{aten::not_equal, aten::ne},
{aten::not_equal_, aten::ne_},
{aten::divide, aten::div},
{aten::divide_, aten::div_},
{aten::multiply, aten::mul},
{aten::multiply_, aten::mul_},
{aten::linalg_matmul, aten::matmul},
{aten::true_divide, aten::div},
{aten::true_divide_, aten::div_},
{aten::concat, aten::cat},
{aten::row_stack, aten::vstack},
{aten::swapdims, aten::transpose},
{aten::swapdims_, aten::transpose_},
{aten::swapaxes, aten::transpose},
{aten::swapaxes_, aten::transpose_},
{aten::moveaxis, aten::movedim},
{aten::special_erf, aten::erf},
{aten::special_erfc, aten::erfc},
{aten::special_erfinv, aten::erfinv},
{aten::special_expit, aten::sigmoid},
{aten::special_exp2, aten::exp2},
{aten::special_expm1, aten::expm1},
{aten::special_logit, aten::logit},
{aten::special_logsumexp, aten::logsumexp},
{aten::special_round, aten::round},
{aten::special_log1p, aten::log1p},
{aten::special_sinc, aten::sinc},
{aten::special_digamma, aten::digamma},
{aten::special_psi, aten::digamma},
{aten::special_i0, aten::i0},
{aten::special_xlogy, aten::xlogy},
{aten::special_log_softmax, aten::log_softmax},
{aten::orgqr, aten::linalg_householder_product},
{aten::adjoint, aten::mH},
{aten::special_multigammaln, aten::mvlgamma},
{aten::special_polygamma, aten::polygamma},
{aten::special_softmax, aten::softmax},
{aten::special_gammainc, aten::igamma},
{aten::special_gammaincc, aten::igammac},
{aten::special_gammaln, aten::lgamma}};
return alias_map;
}
OneFlow — equivalent functor implementation
// OneFlow functor implementing logsumexp with the same max-subtraction
// stabilization as the ATen kernel above.
class LogSumExpFunctor {
 public:
  LogSumExpFunctor() {}
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,
                           const bool& keepdims) const {
    if (x->ndim() == 0) {
      // 0-d input: logsumexp(x) == x; only a cast to float is needed.
      return To(x, JUST(DType::Get(DataType::kFloat)), false);
    } else if (x->nelement() == 0) {
      // Empty input: log(sum(exp(x))) = log(0) = -inf; no stabilization.
      std::shared_ptr<one::Tensor> exp_out = JUST(Exp(x));
      return Log(JUST(ReduceSum(exp_out, axis, keepdims)));
    } else {
      const std::shared_ptr<one::Tensor>& maxes = JUST(Amax(x, axis, true));
      const std::shared_ptr<one::Tensor>& maxes_squeezed =
          (keepdims ? maxes : JUST(SqueezeMultiple(maxes, axis)));
      // Zero out +/-inf maxes to avoid inf - inf = nan in the subtraction.
      JUST(MaskedFillInplace(maxes_squeezed,
                             JUST(ScalarLogicalEqual(JUST(Abs(maxes_squeezed)), INFINITY)), 0));
      std::shared_ptr<one::Tensor> exp_out = JUST(Exp(JUST(Sub(x, maxes, 1, false))));
      return Add(JUST(Log(JUST(ReduceSum(exp_out, axis, keepdims)))), maxes_squeezed, 1, false);
    }
  }

 private:
  // Squeezes every axis listed in `axis`, highest index first so the
  // remaining indices stay valid.
  Maybe<Tensor> SqueezeMultiple(const std::shared_ptr<one::Tensor>& x,
                                const std::vector<int32_t>& axis) const {
    int ndims = x->ndim();
    const auto& dims_to_squeeze = JUST(dim_list_to_bitset(axis, ndims));
    std::shared_ptr<one::Tensor> result = x;
    for (int i = ndims - 1; i >= 0; --i) {
      if ((*dims_to_squeeze)[i]) {
        std::vector<int32_t> dims = {i};
        result = JUST(Squeeze(result, dims));
      }
    }
    return result;
  }
};
