RowSplitsToRowIds
This article walks through the GPU implementation. To understand it, you first need to know what RowSplitsToRowIds does.
std::pair<torch::Tensor, torch::optional<torch::Tensor>>
ComputeTransducerLossCuda(torch::Tensor &logits, // NOLINT
                          const torch::Tensor &targets,
                          const torch::Tensor &logit_lengths,
                          const torch::Tensor &target_lengths, int32_t blank,
                          bool from_log_softmax, bool one_sym_per_frame) {
  torch::DeviceGuard device_guard(logits.device());
  // + 1 here since each sequence is prepended with a blank
  torch::Tensor sizes = logit_lengths * (target_lengths + 1);
  torch::Tensor row_splits = torch::cumsum(sizes, -1, torch::kInt);
  torch::Tensor zero = torch::zeros({1}, row_splits.options());
  row_splits = torch::cat({zero, row_splits}, -1);
  torch::Tensor row_ids = RowSplitsToRowIds(row_splits, logits.size(0));
  ...
From the code above, row_splits stores the cumulative sum of sizes, with a 0 prepended as the first element. In other words, row_splits gives the offset in memory of each batch entry's data. RowSplitsToRowIds computes the batch index each element belongs to; the output has length row_splits[-1], i.e., sum_all_TU. Its implementation is shown below. The second argument is the number of elements, i.e., sum_all_TU. The returned tensor has sum_all_TU entries; each entry is the index of the segment (as delimited by the offsets) that the element falls into, i.e., its batch index. For example, with logit_lengths = [2, 3] and target_lengths = [1, 2], we get sizes = [4, 9], row_splits = [0, 4, 13], and row_ids = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1].
// See https://github.com/k2-fsa/k2/blob/master/k2/csrc/utils.cu#L75
// for the meaning of row splits and row IDs.
/**
  @param row_splits  A 1-D tensor of dtype torch.int32. Its first
                     element should be zero.
  @param num_elems   If -1, it is equal to row_splits[-1].
                     If not -1, it must be equal to row_splits[-1].

  @return Return a 1-D tensor of dtype torch.int32. Its length
          equals num_elems.
 */
torch::Tensor RowSplitsToRowIds(const torch::Tensor &row_splits,
                                int32_t num_elems = -1) {
  torch::CheckedFrom c = "RowSplitsToRowIds";
  auto row_splits_arg = torch::TensorArg(row_splits, "row_splits", 0);
  torch::checkScalarType(c, row_splits_arg, torch::kInt32);
  torch::checkDim(c, row_splits_arg, 1);
  torch::checkContiguous(c, row_splits_arg);

  int32_t num_rows = row_splits.size(0) - 1;
  const int32_t *p_row_splits = row_splits.data_ptr<int32_t>();
  if (num_elems == -1) {
    num_elems = row_splits.cpu().data_ptr<int32_t>()[num_rows];
  }

  torch::Tensor row_ids = torch::empty({num_elems}, row_splits.options());

  ModernGpuAllocator allocator;
  mgpu::load_balance_search(num_elems, p_row_splits, num_rows,
                            row_ids.data_ptr<int32_t>(), allocator);

  return row_ids;
}
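For intuition, here is a minimal CPU-only sketch of the same mapping (not part of the repo; the function name is mine):

#include <cstdint>
#include <vector>

// CPU reference for RowSplitsToRowIds: element i belongs to row b
// iff row_splits[b] <= i < row_splits[b + 1].
std::vector<int32_t> RowSplitsToRowIdsCpu(
    const std::vector<int32_t> &row_splits) {
  int32_t num_rows = static_cast<int32_t>(row_splits.size()) - 1;
  std::vector<int32_t> row_ids(row_splits.back());
  for (int32_t b = 0; b != num_rows; ++b) {
    for (int32_t i = row_splits[b]; i != row_splits[b + 1]; ++i) {
      row_ids[i] = b;  // every element in segment b gets batch index b
    }
  }
  return row_ids;
}
// e.g. row_splits = {0, 4, 13} -> row_ids = {0,0,0,0,1,1,1,1,1,1,1,1,1}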
On the GPU, RowSplitsToRowIds is implemented with mgpu::load_balance_search. Here is an example of how it works:
// https://github.com/moderngpu/moderngpu/blob/master/tests/test_load_balance.cu
#include <moderngpu/kernel_load_balance.hxx>

using namespace mgpu;

int main(int argc, char** argv) {
  standard_context_t context;

  int count = 200030;
  int spacing = 100;

  int num_segments = div_up(count, spacing);
  std::vector<int> segments_host(num_segments);
  for(int i = 0; i < num_segments; ++i)
    segments_host[i] = i * spacing;
  mem_t<int> segments = to_mem(segments_host, context);

  mem_t<int> lbs(count, context);
  load_balance_search(count, segments.data(), num_segments, lbs.data(),
                      context);

  std::vector<int> lbs_host = from_mem(lbs);
  for(size_t i = 0; i < lbs_host.size(); ++i) {
    printf("%4d: %3d\n", (int)i, lbs_host[i]);
    if(lbs_host[i] != i / spacing) {
      printf("ERROR AT %d\n", (int)i);
      exit(0);
    }
  }

  return 0;
}
The output looks like this:
95: 0
96: 0
97: 0
98: 0
99: 0
100: 1
101: 1
102: 1
103: 1
104: 1
105: 1
106: 1
107: 1
108: 1
Once the logic of RowSplitsToRowIds is clear, the rest of the code is easy to follow. moderngpu is also used in k2; we will look at that repo's implementation in a later post.
ComputeLogProbs
__global__ void ComputeLogProbs(const float *logits, const float *denominator,
                                const int32_t *targets,
                                const int32_t *target_lengths, int32_t blank,
                                const int32_t *row_splits,
                                const int32_t *row_ids, int32_t sum_all_TU,
                                int32_t vocab_size, int32_t targets_col,
                                float *log_probs) {
  int32_t idx01 = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx01 >= sum_all_TU) return;  // out-of-boundary

  int32_t b = row_ids[idx01];  // batch index

  // +1 since it is prepended with a blank
  int32_t U_p1 = target_lengths[b] + 1;

  int32_t offset = row_splits[b];

  // idx1 is the linear index within batch b. Since row_ids maps idx01 to
  // the batch whose range contains it, row_splits[b] <= idx01, so idx1 >= 0.
  int32_t idx1 = idx01 - offset;
  int32_t u = idx1 % U_p1;

  const float *p_logits = logits + idx01 * vocab_size;
  const float *p_denominator = denominator + idx01;
  const int32_t *p_targets = targets + b * targets_col;

  float d = *p_denominator;

  float *p_log_probs = log_probs + idx01 * 2;
  p_log_probs[kBlankCol] = p_logits[blank] - d;

  if (u < U_p1 - 1) {
    p_log_probs[kSymCol] = p_logits[p_targets[u]] - d;
  }
}
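In formulas: write $z_{t,u,v}$ for the logit of the current batch entry at position $(t, u)$, and assume denominator holds the per-position log-sum-exp over the vocabulary (this matches how it is subtracted here). Since only the blank and the next target symbol can leave a node $(t, u)$, the alpha/beta recursions never need the other vocabulary entries, so this kernel pre-gathers just two log-probabilities per node into a 2-column tensor:

$$
\log p(\varnothing \mid t, u) = z_{t,u,\varnothing} - \log\sum_{v} e^{z_{t,u,v}}, \qquad
\log p(y_{u+1} \mid t, u) = z_{t,u,\,y_{u+1}} - \log\sum_{v} e^{z_{t,u,v}}
$$

The symbol column is skipped on the last row $u = U$ (U_p1 - 1 in the code), which has no next symbol.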
RNN-T Loss
Forward computation
// This function is based on
// https://github.com/HawkAaron/warp-transducer/blob/master/include/detail/gpu_rnnt_kernel.h#L12
// Call it like <<<batch_size, maxU>>>
__global__ void ComputeAlpha(const float *log_probs,
                             const int32_t *logit_lengths,
                             const int32_t *target_lengths,
                             const int32_t *row_splits, int32_t max_T,
                             int32_t max_U_p1, int32_t *counter, float *alpha,
                             float *total_scores) {
  int32_t b = blockIdx.x;
  int32_t u = threadIdx.x;

  int32_t T = logit_lengths[b];
  int32_t U_p1 = target_lengths[b] + 1;

  int32_t offset = row_splits[b];
  float *p_alpha = alpha + offset;
  const float *p_log_probs = log_probs + offset * 2;

  if (u == 0) {
    p_alpha[0] = 0;
  }
  __syncthreads();

  for (int32_t n = 1; n < T + U_p1 - 1; ++n) {
    int32_t t = n - u;
    float *p_alpha_t = p_alpha + t * U_p1;
    float *p_alpha_t_m1 = p_alpha + (t - 1) * U_p1;
    const float *p_log_probs_t = p_log_probs + t * U_p1 * 2;
    const float *p_log_probs_t_m1 = p_log_probs + (t - 1) * U_p1 * 2;
    if (u == 0) {
      if (t > 0 && t < T) {
        // when u = 0, alpha(t, 0) = alpha(t-1, 0) + log_probs(t-1, 0).blank
        *p_alpha_t = *p_alpha_t_m1 + p_log_probs_t_m1[kBlankCol];
      }
    } else if (u < U_p1) {
      if (t == 0) {
        // when t = 0, alpha(0, u) = alpha(0, u-1) + log_probs(0, u-1).symbol
        p_alpha[u] = p_alpha[u - 1] + (p_log_probs + (u - 1) * 2)[kSymCol];
      } else if (t > 0 && t < T) {
        // alpha(t, u) = log_sum_exp(alpha(t-1, u) + log_probs(t-1, u).blank,
        //                           alpha(t, u-1) + log_probs(t, u-1).symbol)
        float skip_prob =
            p_alpha_t_m1[u] + (p_log_probs_t_m1 + u * 2)[kBlankCol];
        float emit_prob =
            p_alpha_t[u - 1] + (p_log_probs_t + (u - 1) * 2)[kSymCol];
        p_alpha_t[u] = LogAdd(skip_prob, emit_prob);
      }
    }
    __syncthreads();
  }

  if (u == 0) {
    total_scores[b] = *(p_alpha + T * U_p1 - 1) +
                      (p_log_probs + (T * U_p1 - 1) * 2)[kBlankCol];
  }
}
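The kernel assigns one thread per u and sweeps the anti-diagonals n = t + u; each __syncthreads() separates consecutive diagonals, so alpha(t-1, u) and alpha(t, u-1) are ready before they are read. In equations (alpha in the log domain, $U$ denoting target_lengths[b], i.e. U_p1 - 1), this is the standard transducer forward recursion:

$$
\begin{aligned}
\alpha(0, 0) &= 0 \\
\alpha(t, 0) &= \alpha(t-1, 0) + \log p(\varnothing \mid t-1, 0) \\
\alpha(0, u) &= \alpha(0, u-1) + \log p(y_u \mid 0, u-1) \\
\alpha(t, u) &= \mathrm{LogAdd}\bigl(\alpha(t-1, u) + \log p(\varnothing \mid t-1, u),\;
                                     \alpha(t, u-1) + \log p(y_u \mid t, u-1)\bigr)
\end{aligned}
$$

and the total score is $\log P(\mathbf{y} \mid \mathbf{x}) = \alpha(T-1, U) + \log p(\varnothing \mid T-1, U)$.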
Backward computation

// It is based on
// https://github.com/HawkAaron/warp-transducer/blob/master/include/detail/gpu_rnnt_kernel.h#L80
//
// Call it like <<<batch_size, max_U_p1>>>
__global__ void ComputeBeta(const float *log_probs,
                            const int32_t *logit_lengths,
                            const int32_t *target_lengths,
                            const int32_t *row_splits, int32_t max_T,
                            int32_t max_U_p1, int32_t *counter, float *beta) {
  int32_t b = blockIdx.x;
  int32_t u = threadIdx.x;

  int32_t T = logit_lengths[b];
  int32_t U_p1 = target_lengths[b] + 1;

  int32_t offset = row_splits[b];
  float *p_beta = beta + offset;
  const float *p_log_probs = log_probs + offset * 2;

  if (u == 0) {
    // beta(T-1, U_p1-1) = log_probs(T-1, U_p1-1).blank
    (p_beta + T * U_p1)[-1] = (p_log_probs + T * U_p1 * 2 - 2)[kBlankCol];
  }
  __syncthreads();

  for (int32_t n = T + U_p1 - 2; n >= 0; --n) {
    int32_t t = n - u;
    float *p_beta_t = p_beta + t * U_p1;
    float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;
    const float *p_log_probs_t = p_log_probs + t * U_p1 * 2;
    if (u == U_p1 - 1) {
      if (t >= 0 && t < T - 1) {
        // when u = U_p1 - 1,
        // beta(t, U_p1-1) = beta(t+1, U_p1-1) + log_probs(t, U_p1-1).blank
        p_beta_t[U_p1 - 1] =
            p_beta_t_p1[U_p1 - 1] + (p_log_probs_t + (U_p1 - 1) * 2)[kBlankCol];
      }
    } else if (u < U_p1) {
      if (t == T - 1) {
        // when t = T - 1,
        // beta(T-1, u) = beta(T-1, u+1) + log_probs(T-1, u).symbol
        (p_beta + (T - 1) * U_p1)[u] =
            (p_beta + (T - 1) * U_p1)[u + 1] +
            (p_log_probs + ((T - 1) * U_p1 + u) * 2)[kSymCol];
      } else if (t >= 0 && t < T - 1) {
        // beta(t, u) = log_sum_exp(beta(t+1, u) + log_probs(t, u).blank,
        //                          beta(t, u+1) + log_probs(t, u).symbol)
        float skip_prob = p_beta_t_p1[u] + (p_log_probs_t + u * 2)[kBlankCol];
        float emit_prob = p_beta_t[u + 1] + (p_log_probs_t + u * 2)[kSymCol];
        p_beta_t[u] = LogAdd(skip_prob, emit_prob);
      }
    }
    __syncthreads();
  }
}
Gradient computation

__global__ void ComputeGradient(
    const float *logits, const float *denominator, const int32_t *targets,
    const int32_t *logit_lengths, const int32_t *target_lengths, int32_t blank,
    const int32_t *row_splits, const int32_t *row_ids, int32_t sum_all_TU,
    int32_t vocab_size, int32_t targets_col, const float *alpha,
    const float *beta, float *gradient) {
  int32_t idx01 = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx01 >= sum_all_TU) return;  // out-of-boundary

  int32_t b = row_ids[idx01];  // batch index

  // +1 since it is prepended with a blank
  int32_t U_p1 = target_lengths[b] + 1;
  int32_t T = logit_lengths[b];

  int32_t offset = row_splits[b];

  // compute the (t, u) index within the current batch
  int32_t idx1 = idx01 - offset;
  int32_t t = idx1 / U_p1;
  int32_t u = idx1 % U_p1;

  const float *p_logits_t_u = logits + idx01 * vocab_size;
  const float *p_denominator = denominator + offset;
  const float *p_denominator_t = p_denominator + t * U_p1;
  const int32_t *p_targets = targets + b * targets_col;
  const float *p_alpha = alpha + offset;
  const float *p_alpha_t = p_alpha + t * U_p1;
  const float *p_beta = beta + offset;
  const float *p_beta_t = p_beta + t * U_p1;
  const float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;
  float *p_grad_t_u = gradient + idx01 * vocab_size;

  // nll
  float loss = -1 * p_beta[0];
  if (isinf(loss) || isnan(loss)) {
    for (int32_t v = 0; v != vocab_size; ++v) {
      p_grad_t_u[v] = 0;
    }
    return;
  }

  float c = p_alpha_t[u] + loss - p_denominator_t[u];
  int32_t target_u = (u < U_p1 - 1) ? p_targets[u] : -1;  // -1 is not used

  // TODO(fangjun): Use separate threads to compute the gradient
  // so that we don't have a `for` loop here
  for (int32_t v = 0; v != vocab_size; ++v) {
    float g = p_logits_t_u[v] + c;
    float val = 0;
    if (v == blank && t == T - 1 && u == U_p1 - 1) {
      // last blank transition
      val = expf(g + p_beta_t[u]) - expf(g);
    } else if (v == blank && t < T - 1) {
      val = expf(g + p_beta_t[u]) - expf(g + p_beta_t_p1[u]);
    } else if (u < U_p1 - 1 && v == target_u) {
      val = expf(g + p_beta_t[u]) - expf(g + p_beta_t[u + 1]);
    } else {
      val = expf(g + p_beta_t[u]);
    }
    p_grad_t_u[v] = val;
  }
}
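Unrolling the code: with c = alpha(t,u) + loss - denom(t,u), i.e. $c = \alpha(t,u) - \log P - \mathrm{denom}(t,u)$, and $g = z_{t,u,v} + c$, each $\exp(\cdot)$ term equals $y_{t,u}(v)\,\alpha(t,u)\,(\cdots)/P$ in the probability domain, where $y_{t,u}(v)$ is the softmax of the logits. The loop therefore computes Graves' gradient of the negative log-likelihood with respect to the logits (below, $\alpha$, $\beta$, $P$ denote probabilities rather than log-probabilities, and $U$ denotes U_p1 - 1):

$$
\frac{\partial \mathcal{L}}{\partial z_{t,u,v}}
= \frac{\alpha(t,u)\, y_{t,u}(v)}{P}
\left(\beta(t,u) -
\begin{cases}
\beta(t+1, u)  & v = \varnothing,\ t < T-1 \\
1              & v = \varnothing,\ t = T-1,\ u = U \\
\beta(t, u+1)  & v = y_{u+1} \\
0              & \text{otherwise}
\end{cases}\right)
$$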
RNA Loss
Forward computation
// Call it like <<<batch_size, maxU>>>
__global__ void ComputeAlphaOneSymPerFrame(
    const float *log_probs, const int32_t *logit_lengths,
    const int32_t *target_lengths, const int32_t *row_splits, int32_t max_T,
    int32_t max_U_p1, int32_t *counter, float *alpha, float *total_scores) {
  int32_t b = blockIdx.x;
  int32_t u = threadIdx.x;

  int32_t T = logit_lengths[b];
  int32_t U_p1 = target_lengths[b] + 1;
  int32_t diff = T - 1 - (U_p1 - 1);

  int32_t offset = row_splits[b];
  float *p_alpha = alpha + offset;
  const float *p_log_probs = log_probs + offset * 2;

  if (u == 0) {
    p_alpha[0] = 0;
  }
  __syncthreads();

  for (int32_t n = 1; n < T + U_p1 - 1; ++n) {
    int32_t t = n - u;
    if (u <= t && t - u <= diff) {
      float *p_alpha_t = p_alpha + t * U_p1;
      float *p_alpha_t_m1 = p_alpha + (t - 1) * U_p1;
      const float *p_log_probs_t_m1 = p_log_probs + (t - 1) * U_p1 * 2;
      if (u == 0) {
        if (t > 0 && t <= diff) {
          // when u = 0, alpha(t, 0) = alpha(t-1, 0) + log_probs(t-1, 0).blank
          *p_alpha_t = *p_alpha_t_m1 + p_log_probs_t_m1[kBlankCol];
        }
      } else if (u < U_p1) {
        if (t == u) {
          // diagonal
          // alpha(t, u) = alpha(t-1, u-1) + log_probs(t-1, u-1).symbol
          p_alpha_t[u] =
              p_alpha_t_m1[u - 1] + (p_log_probs_t_m1 + (u - 1) * 2)[kSymCol];
        } else {
          // alpha(t, u) = log_sum_exp(
          //     alpha(t-1, u) + log_probs(t-1, u).blank,
          //     alpha(t-1, u-1) + log_probs(t-1, u-1).symbol)
          float skip_prob =
              p_alpha_t_m1[u] + (p_log_probs_t_m1 + u * 2)[kBlankCol];
          float emit_prob =
              p_alpha_t_m1[u - 1] + (p_log_probs_t_m1 + (u - 1) * 2)[kSymCol];
          p_alpha_t[u] = LogAdd(skip_prob, emit_prob);
        }
      }
    }
    __syncthreads();
  }

  if (u == 0) {
    total_scores[b] = *(p_alpha + T * U_p1 - 1) +
                      (p_log_probs + (T * U_p1 - 1) * 2)[kBlankCol];
  }
}
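Compared with ComputeAlpha, every frame must emit exactly one symbol or one blank, so t always advances. A node (t, u) is reachable only inside the strip u <= t (u symbols need at least u frames) and t - u <= diff, where diff = (T - 1) - (U_p1 - 1) ensures enough frames remain for the outstanding symbols. The horizontal "emit without consuming a frame" move of RNN-T is replaced by a diagonal move:

$$
\alpha(t, u) = \mathrm{LogAdd}\bigl(\alpha(t-1, u) + \log p(\varnothing \mid t-1, u),\;
                                    \alpha(t-1, u-1) + \log p(y_u \mid t-1, u-1)\bigr)
$$

with the edge cases $\alpha(t, 0) = \alpha(t-1, 0) + \log p(\varnothing \mid t-1, 0)$ for $t \le \mathrm{diff}$, and, on the diagonal $t = u$, only the symbol term surviving.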
Backward computation
// Call it like <<<batch_size, max_U_p1>>>
__global__ void ComputeBetaOneSymPerFrame(const float *log_probs,
                                          const int32_t *logit_lengths,
                                          const int32_t *target_lengths,
                                          const int32_t *row_splits,
                                          int32_t max_T, int32_t max_U_p1,
                                          int32_t *counter, float *beta) {
  int32_t b = blockIdx.x;
  int32_t u = threadIdx.x;

  int32_t T = logit_lengths[b];
  int32_t U_p1 = target_lengths[b] + 1;

  int32_t offset = row_splits[b];
  float *p_beta = beta + offset;
  const float *p_log_probs = log_probs + offset * 2;

  if (u == 0) {
    (p_beta + T * U_p1)[-1] = (p_log_probs + T * U_p1 * 2 - 2)[kBlankCol];
  }
  __syncthreads();

  int32_t diff = T - 1 - (U_p1 - 1);

  for (int32_t n = T + U_p1 - 2; n >= 0; --n) {
    int32_t t = n - u;
    float *p_beta_t = p_beta + t * U_p1;
    float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;
    const float *p_log_probs_t = p_log_probs + t * U_p1 * 2;
    if (u == U_p1 - 1) {
      // beta(t, U_p1-1) = beta(t+1, U_p1-1) + log_probs(t, U_p1-1).blank
      if (u <= t && t < T - 1) {
        p_beta_t[U_p1 - 1] =
            p_beta_t_p1[U_p1 - 1] + (p_log_probs_t + (U_p1 - 1) * 2)[kBlankCol];
      }
    } else if (u < U_p1) {
      if (u <= t && t - u <= diff) {
        if (t - u == diff) {
          // beta(t, u) = beta(t+1, u+1) + log_probs(t, u).symbol
          p_beta_t[u] = p_beta_t_p1[u + 1] + (p_log_probs_t + u * 2)[kSymCol];
        } else {
          // beta(t, u) = log_sum_exp(beta(t+1, u) + log_probs(t, u).blank,
          //                          beta(t+1, u+1) + log_probs(t, u).symbol)
          float skip_prob = p_beta_t_p1[u] + (p_log_probs_t + u * 2)[kBlankCol];
          float emit_prob =
              p_beta_t_p1[u + 1] + (p_log_probs_t + u * 2)[kSymCol];
          p_beta_t[u] = LogAdd(skip_prob, emit_prob);
        }
      }
    }
    __syncthreads();
  }
}
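The backward pass runs over the same strip. On the boundary $t - u = \mathrm{diff}$ all blanks have been used up, so only the diagonal symbol move remains:

$$
\beta(t, u) =
\begin{cases}
\beta(t+1, u+1) + \log p(y_{u+1} \mid t, u) & t - u = \mathrm{diff} \\
\mathrm{LogAdd}\bigl(\beta(t+1, u) + \log p(\varnothing \mid t, u),\;
                     \beta(t+1, u+1) + \log p(y_{u+1} \mid t, u)\bigr) & t - u < \mathrm{diff}
\end{cases}
$$

anchored at $\beta(T-1, U) = \log p(\varnothing \mid T-1, U)$ with $U$ = U_p1 - 1; on the last row $u = U$ there is no next symbol, so only the blank move is kept.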
Gradient computation

__global__ void ComputeGradientOneSymPerFrame(
    const float *logits, const float *denominator, const int32_t *targets,
    const int32_t *logit_lengths, const int32_t *target_lengths, int32_t blank,
    const int32_t *row_splits, const int32_t *row_ids, int32_t sum_all_TU,
    int32_t vocab_size, int32_t targets_col, const float *alpha,
    const float *beta, float *gradient) {
  int32_t idx01 = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx01 >= sum_all_TU) return;  // out-of-boundary

  int32_t b = row_ids[idx01];  // batch index

  // +1 since it is prepended with a blank
  int32_t U_p1 = target_lengths[b] + 1;
  int32_t T = logit_lengths[b];

  int32_t offset = row_splits[b];

  // compute the (t, u) index within the current batch
  int32_t idx1 = idx01 - offset;
  int32_t t = idx1 / U_p1;
  int32_t u = idx1 % U_p1;

  const float *p_logits_t_u = logits + idx01 * vocab_size;
  const float *p_denominator = denominator + offset;
  const float *p_denominator_t = p_denominator + t * U_p1;
  const int32_t *p_targets = targets + b * targets_col;
  const float *p_alpha = alpha + offset;
  const float *p_alpha_t = p_alpha + t * U_p1;
  const float *p_beta = beta + offset;
  const float *p_beta_t = p_beta + t * U_p1;
  const float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;
  float *p_grad_t_u = gradient + idx01 * vocab_size;

  // nll
  float loss = -1 * p_beta[0];

  int32_t diff = T - 1 - (U_p1 - 1);
  if (u > t || t - u > diff || isinf(loss) || isnan(loss)) {
    for (int32_t v = 0; v != vocab_size; ++v) {
      p_grad_t_u[v] = 0;
    }
    return;
  }

  float c = p_alpha_t[u] + loss - p_denominator_t[u];
  int32_t target_u = (u < U_p1 - 1) ? p_targets[u] : -1;  // -1 is not used

  for (int32_t v = 0; v != vocab_size; ++v) {
    float g = p_logits_t_u[v] + c;
    float val = 0;
    if (v == blank && t == T - 1 && u == U_p1 - 1) {
      // last blank transition
      val = expf(g + p_beta_t[u]) - expf(g);
    } else if (v == blank && t - u < diff) {
      val = expf(g + p_beta_t[u]) - expf(g + p_beta_t_p1[u]);
    } else if (u < U_p1 - 1 && v == target_u) {
      val = expf(g + p_beta_t[u]) - expf(g + p_beta_t_p1[u + 1]);
    } else {
      val = expf(g + p_beta_t[u]);
    }
    p_grad_t_u[v] = val;
  }  // end v
}
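The case analysis mirrors the RNN-T gradient above, with the symbol successor moved diagonally to $(t+1, u+1)$ and the blank transition only allowed while $t - u < \mathrm{diff}$; nodes outside the strip get zero gradient:

$$
\frac{\partial \mathcal{L}}{\partial z_{t,u,v}}
= \frac{\alpha(t,u)\, y_{t,u}(v)}{P}
\left(\beta(t,u) -
\begin{cases}
\beta(t+1, u)    & v = \varnothing,\ t - u < \mathrm{diff} \\
1                & v = \varnothing,\ t = T-1,\ u = U \\
\beta(t+1, u+1)  & v = y_{u+1} \\
0                & \text{otherwise}
\end{cases}\right)
$$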
References
- https://github.com/csukuangfj/optimized_transducer
- https://arxiv.org/abs/1211.3711
- https://lorenlugosch.github.io/posts/2020/11/transducer/