RNN-T/RNA Loss GPU 实现

技术

RowSplitsToRowIds

本文看下GPU的实现,要理解需要先了解RowSplitsToRowIds的功能。

  
std::pair<torch::Tensor, torch::optional<torch::Tensor>>
ComputeTransducerLossCuda(torch::Tensor &logits,  // NOLINT
                          const torch::Tensor &targets,
                          const torch::Tensor &logit_lengths,
                          const torch::Tensor &target_lengths, int32_t blank,
                          bool from_log_softmax, bool one_sym_per_frame) {
  torch::DeviceGuard device_guard(logits.device());

  // + 1 here since each sequence is prepended with a blank, so batch b
  // occupies sizes[b] = T_b * (U_b + 1) grid points.
  torch::Tensor sizes = logit_lengths * (target_lengths + 1);
  // Inclusive prefix sum of sizes, with a leading 0 prepended below:
  // row_splits[b] is the memory offset of batch b's data.
  torch::Tensor row_splits = torch::cumsum(sizes, -1, torch::kInt);
  torch::Tensor zero = torch::zeros({1}, row_splits.options());
  row_splits = torch::cat({zero, row_splits}, -1);
  // row_ids[i] is the batch index that element i belongs to.
  torch::Tensor row_ids = RowSplitsToRowIds(row_splits, logits.size(0));

  ...

从上代码可知,row_splits 中存储的是 sizes 的前缀和(cumsum),并在最前面拼接了一个 0,即第一个元素是 0。row_splits 表示的是各个 batch 对应的数据在内存中的 offset。

RowSplitsToRowIds 求的是各元素对应的 batch index 值,返回值长度为 row_splits[-1],即 sum_all_TU。其实现如下:第二个输入是元素的个数,即 sum_all_TU;返回的 tensor 有 sum_all_TU 个元素,对应值是按 offset 分割的各个分段的编号,即 batch index。

  
// See https://github.com/k2-fsa/k2/blob/master/k2/csrc/utils.cu#L75
// for the meaning of row splits and row IDs.
/**
  Convert a row-splits tensor into a row-ids tensor.

  @param row_splits  A 1-D tensor of dtype torch.int32. Its first
                     element should be zero.
  @param num_elems   If -1, it is equal to row_splits[-1].
                     If not -1, it must be equal to row_splits[-1].

  @return Return a 1-D tensor of dtype torch.int32. Its length
          equals num_elems; entry i is the index of the row
          (i.e. the batch) that element i belongs to.
 */
torch::Tensor RowSplitsToRowIds(const torch::Tensor &row_splits,
                                int32_t num_elems = -1) {
  torch::CheckedFrom c = "RowSplitsToRowIds";
  auto row_splits_arg = torch::TensorArg(row_splits, "row_splits", 0);
  torch::checkScalarType(c, row_splits_arg, torch::kInt32);
  torch::checkDim(c, row_splits_arg, 1);
  torch::checkContiguous(c, row_splits_arg);

  int32_t num_rows = row_splits.size(0) - 1;
  const int32_t *p_row_splits = row_splits.data_ptr<int32_t>();
  if (num_elems == -1) {
    // row_splits lives on the GPU; copy to CPU to read the last entry.
    num_elems = row_splits.cpu().data_ptr<int32_t>()[num_rows];
  }

  torch::Tensor row_ids = torch::empty({num_elems}, row_splits.options());
  ModernGpuAllocator allocator;
  // load_balance_search writes, for each of the num_elems elements, the
  // index of the segment (row) it falls into according to p_row_splits.
  mgpu::load_balance_search(num_elems, p_row_splits, num_rows,
                            row_ids.data_ptr<int32_t>(), allocator);
  return row_ids;
}

RowSplitsToRowIds是通过mgpu::load_balance_search实现的。其例子如下:

  
// https://github.com/moderngpu/moderngpu/blob/master/tests/test_load_balance.cu
#include <moderngpu/kernel_load_balance.hxx>

using namespace mgpu;

int main(int argc, char** argv) {
  standard_context_t context;

  int count = 200030;
  int spacing = 100;

  // Segment i starts at offset i * spacing; segments_host plays the role
  // of a row-splits array (without the trailing total).
  int num_segments = div_up(count, spacing);
  std::vector<int> segments_host(num_segments);
  for (int i = 0; i < num_segments; ++i)
    segments_host[i] = i * spacing;
  mem_t<int> segments = to_mem(segments_host, context);

  mem_t<int> lbs(count, context);
  load_balance_search(count, segments.data(), num_segments, lbs.data(),
                      context);

  // Expect lbs[i] == i / spacing, i.e. the segment index of element i.
  std::vector<int> lbs_host = from_mem(lbs);
  for (size_t i = 0; i < lbs_host.size(); ++i) {
    printf("%4d: %3d\n", (int)i, lbs_host[i]);
    if (lbs_host[i] != (int)(i / spacing)) {
      printf("ERROR AT %d\n", (int)i);
      exit(0);
    }
  }

  return 0;
}

输出如下:

  
  95:   0  
  96:   0  
  97:   0  
  98:   0  
  99:   0  
 100:   1  
 101:   1  
 102:   1  
 103:   1  
 104:   1  
 105:   1  
 106:   1  
 107:   1  
 108:   1  

明白RowSplitsToRowIds的逻辑后,后面的代码就容易理解了。

moderngpu 在 k2 中也有使用,后面看下这个 repo 的实现。

ComputeLogProbs

  
\_\_global\_\_ void ComputeLogProbs(const float *logits, const float *denominator,  
                                const int32\_t *targets,  
                                const int32\_t *target\_lengths, int32\_t blank,  
                                const int32\_t *row\_splits,  
                                const int32\_t *row\_ids, int32\_t sum\_all\_TU,  
                                int32\_t vocab\_size, int32\_t targets\_col,  
                                float *log\_probs) {  
  int32\_t idx01 = blockDim.x * blockIdx.x + threadIdx.x;  
  if (idx01 >= sum_all_TU) return;  // out-of-boundary  
  
  int32\_t b = row_ids[idx01];  // batch index  
  
  // +1 since it is prepended with a blank  
  int32\_t U_p1 = target_lengths[b] + 1;  
  int32\_t offset = row_splits[b];  
  // 从当前batch offset 开始计算索引,左负,右正。  
  int32\_t idx1 = idx01 - offset;  
  
  // 不是当前batch中的thread,也会映射到正确的 u 上。  
  // -9 % 10 = 1  
  int32\_t u = idx1 % U_p1;  
  
  const float *p_logits = logits + idx01 * vocab_size;  
  const float *p_denominator = denominator + idx01;  
  const int32\_t *p_targets = targets + b * targets_col;  
  
  float d = *p_denominator;  
  
  float *p_log_probs = log_probs + idx01 * 2;  
  p_log_probs[kBlankCol] = p_logits[blank] - d;  
  if (u < U_p1 - 1) {  
    p_log_probs[kSymCol] = p_logits[p_targets[u]] - d;  
  }  
}  

RNN-T Loss

前向计算

  
// This function is based on  
// https://github.com/HawkAaron/warp-transducer/blob/master/include/detail/gpu\_rnnt\_kernel.h#L12  
// Call it like <<<batch\_size, maxU>>>  
\_\_global\_\_ void ComputeAlpha(const float *log\_probs,  
                             const int32\_t *logit\_lengths,  
                             const int32\_t *target\_lengths,  
                             const int32\_t *row\_splits, int32\_t max\_T,  
                             int32\_t max\_U\_p1, int32\_t *counter, float *alpha,  
                             float *total\_scores) {  
  int32\_t b = blockIdx.x;  
  int32\_t u = threadIdx.x;  
  int32\_t T = logit_lengths[b];  
  int32\_t U_p1 = target_lengths[b] + 1;  
  
  int32\_t offset = row_splits[b];  
  float *p_alpha = alpha + offset;  
  const float *p_log_probs = log_probs + offset * 2;  
  
  if (u == 0) {  
    p_alpha[0] = 0;  
  }  
  
  __syncthreads();  
  
  for (int32\_t n = 1; n < T + U_p1 - 1; ++n) {  
    int32\_t t = n - u;  
    float *p\_alpha\_t = p_alpha + t * U_p1;  
    float *p_alpha_t_m1 = p_alpha + (t - 1) * U_p1;  
    const float *p\_log\_probs\_t = p_log_probs + t * U_p1 * 2;  
    const float *p_log_probs_t_m1 = p_log_probs + (t - 1) * U_p1 * 2;  
    if (u == 0) {  
      if (t > 0 && t < T) {  
        // when u = 0, alpha(t, 0) = alpha(t-1, 0) + log\_probs(t-1, 0).blank  
        *p\_alpha\_t = *p_alpha_t_m1 + p_log_probs_t_m1[kBlankCol];  
      }  
    } else if (u < U_p1) {  
      if (t == 0) {  
        // when t = 0, alpha(0, u) = alpha(0, u-1) + log\_probs(0, u-1).symbol  
        p_alpha[u] = p_alpha[u - 1] + (p_log_probs + (u - 1) * 2)[kSymCol];  
      } else if (t > 0 && t < T) {  
        // alpha(t, u) = log\_sum\_exp(alpha(t-1,u) + log\_probs(t-1, u).blank,  
        //                           alpha(t, u-1) + log\_probs(t, u-1).symbol)  
        float skip_prob =  
            p_alpha_t_m1[u] + (p_log_probs_t_m1 + u * 2)[kBlankCol];  
        float emit_prob =  
            p\_alpha\_t[u - 1] + (p\_log\_probs\_t + (u - 1) * 2)[kSymCol];  
        p\_alpha\_t[u] = LogAdd(skip_prob, emit_prob);  
      }  
    }  
    __syncthreads();  
  }  
  
  if (u == 0) {  
    total_scores[b] = *(p_alpha + T * U_p1 - 1) +  
                      (p_log_probs + (T * U_p1 - 1) * 2)[kBlankCol];  
  }  
}  

反向计算

  
/ It is based on  
// https://github.com/HawkAaron/warp-transducer/blob/master/include/detail/gpu\_rnnt\_kernel.h#L80  
//  
// Call it like <<<batch\_size, max\_U\_p1>>>  
\_\_global\_\_ void ComputeBeta(const float *log\_probs,  
                            const int32\_t *logit\_lengths,  
                            const int32\_t *target\_lengths,  
                            const int32\_t *row\_splits, int32\_t max\_T,  
                            int32\_t max\_U\_p1, int32\_t *counter, float *beta) {  
  int32\_t b = blockIdx.x;  
  int32\_t u = threadIdx.x;  
  int32\_t T = logit_lengths[b];  
  int32\_t U_p1 = target_lengths[b] + 1;  
  
  int32\_t offset = row_splits[b];  
  float *p_beta = beta + offset;  
  const float *p_log_probs = log_probs + offset * 2;  
  
  if (u == 0) {  
    (p_beta + T * U_p1)[-1] = (p_log_probs + T * U_p1 * 2 - 2)[kBlankCol];  
  }  
  
  __syncthreads();  
  
  for (int32\_t n = T + U_p1 - 2; n >= 0; --n) {  
    int32\_t t = n - u;  
    float *p\_beta\_t = p_beta + t * U_p1;  
    float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;  
    const float *p\_log\_probs\_t = p_log_probs + t * U_p1 * 2;  
    if (u == U_p1 - 1) {  
      if (t >= 0 && t < T - 1) {  
        // when u = U\_p1 - 1,  
        // beta(t, U\_p1-1) = beta(t+1, U\_p1-1) + lop\_probs(t, U\_p1-1).blank  
        p\_beta\_t[U_p1 - 1] =  
            p_beta_t_p1[U_p1 - 1] + (p\_log\_probs\_t + (U_p1 - 1) * 2)[kBlankCol];  
      }  
    } else if (u < U_p1) {  
      if (t == T - 1) {  
        // when t = T - 1,  
        // beta(T-1 u) =  beta(T-1, u+1) + log\_probs(T-1, u).symbol  
        (p_beta + (T - 1) * U_p1)[u] =  
            (p_beta + (T - 1) * U_p1)[u + 1] +  
            (p_log_probs + ((T - 1) * U_p1 + u) * 2)[kSymCol];  
      } else if (t >= 0 && t < T - 1) {  
        // beta(t, u) = log\_sum\_exp(beta(t+1,u) + log\_probs(t, u).blank,  
        //                           beta(t, u+1) + log\_probs(t, u).symbol)  
        float skip_prob = p_beta_t_p1[u] + (p\_log\_probs\_t + u * 2)[kBlankCol];  
        float emit_prob = p\_beta\_t[u + 1] + (p\_log\_probs\_t + u * 2)[kSymCol];  
        p\_beta\_t[u] = LogAdd(skip_prob, emit_prob);  
      }  
    }  
    __syncthreads();  
  }  
}  

梯度计算

  
\_\_global\_\_ void ComputeGradient(  
    const float *logits, const float *denominator, const int32\_t *targets,  
    const int32\_t *logit\_lengths, const int32\_t *target\_lengths, int32\_t blank,  
    const int32\_t *row\_splits, const int32\_t *row\_ids, int32\_t sum\_all\_TU,  
    int32\_t vocab\_size, int32\_t targets\_col, const float *alpha,  
    const float *beta, float *gradient) {  
  int32\_t idx01 = blockDim.x * blockIdx.x + threadIdx.x;  
  if (idx01 >= sum_all_TU) return;  // out-of-boundary  
  
  int32\_t b = row_ids[idx01];  // batch index  
  
  // +1 since it is prepended with a blank  
  int32\_t U_p1 = target_lengths[b] + 1;  
  int32\_t T = logit_lengths[b];  
  int32\_t offset = row_splits[b];  
  
  // 计算当前batch中的 t,u index  
  int32\_t idx1 = idx01 - offset;  
  int32\_t t = idx1 / U_p1;  
  int32\_t u = idx1 % U_p1;  
  
  const float *p_logits_t_u = logits + idx01 * vocab_size;  
  const float *p_denominator = denominator + offset;  
  const float *p\_denominator\_t = p_denominator + t * U_p1;  
  const int32\_t *p_targets = targets + b * targets_col;  
  
  const float *p_alpha = alpha + offset;  
  const float *p\_alpha\_t = p_alpha + t * U_p1;  
  
  const float *p_beta = beta + offset;  
  const float *p\_beta\_t = p_beta + t * U_p1;  
  const float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;  
  
  float *p_grad_t_u = gradient + idx01 * vocab_size;  
  
  // nll  
  float loss = -1 * p_beta[0];  
  
  if (isinf(loss) || isnan(loss)) {  
    for (int32\_t v = 0; v != vocab_size; ++v) {  
      p_grad_t_u[v] = 0;  
    }  
    return;  
  }  
  
  float c = p\_alpha\_t[u] + loss - p\_denominator\_t[u];  
  
  int32\_t target_u = (u < U_p1 - 1) ? p_targets[u] : -1;  // -1 is not used  
  
  // TODO(fangjun): Use separate threads to compute the gradient  
  // so that we don't have a `for` loop here  
  for (int32\_t v = 0; v != vocab_size; ++v) {  
    float g = p_logits_t_u[v] + c;  
    float val = 0;  
    if (v == blank && t == T - 1 && u == U_p1 - 1) {  
      // last blank transition  
      val = expf(g + p\_beta\_t[u]) - expf(g);  
    } else if (v == blank && t < T - 1) {  
      val = expf(g + p\_beta\_t[u]) - expf(g + p_beta_t_p1[u]);  
    } else if (u < U_p1 - 1 && v == target_u) {  
      val = expf(g + p\_beta\_t[u]) - expf(g + p\_beta\_t[u + 1]);  
    } else {  
      val = expf(g + p\_beta\_t[u]);  
    }  
  
    p_grad_t_u[v] = val;  
  }  
}  

RNA Loss

前向计算

  
// Call it like <<<batch\_size, maxU>>>  
\_\_global\_\_ void ComputeAlphaOneSymPerFrame(  
    const float *log\_probs, const int32\_t *logit\_lengths,  
    const int32\_t *target\_lengths, const int32\_t *row\_splits, int32\_t max\_T,  
    int32\_t max\_U\_p1, int32\_t *counter, float *alpha, float *total\_scores) {  
  int32\_t b = blockIdx.x;  
  int32\_t u = threadIdx.x;  
  int32\_t T = logit_lengths[b];  
  int32\_t U_p1 = target_lengths[b] + 1;  
  
  int32\_t diff = T - 1 - (U_p1 - 1);  
  
  int32\_t offset = row_splits[b];  
  float *p_alpha = alpha + offset;  
  const float *p_log_probs = log_probs + offset * 2;  
  
  if (u == 0) {  
    p_alpha[0] = 0;  
  }  
  
  __syncthreads();  
  
  for (int32\_t n = 1; n < T + U_p1 - 1; ++n) {  
    int32\_t t = n - u;  
    if (u <= t && t - u <= diff) {  
      float *p\_alpha\_t = p_alpha + t * U_p1;  
      float *p_alpha_t_m1 = p_alpha + (t - 1) * U_p1;  
      const float *p_log_probs_t_m1 = p_log_probs + (t - 1) * U_p1 * 2;  
      if (u == 0) {  
        if (t > 0 && t <= diff) {  
          // when u = 0, alpha(t, 0) = alpha(t-1, 0) + log\_probs(t-1, 0).blank  
          *p\_alpha\_t = *p_alpha_t_m1 + p_log_probs_t_m1[kBlankCol];  
        }  
      } else if (u < U_p1) {  
        if (t == u) {  
          // diagonal  
          // alpha(t, u) = alpha(t-1, u-1) + log\_probs(t-1, u-1).symbol  
          p\_alpha\_t[u] =  
              p_alpha_t_m1[u - 1] + (p_log_probs_t_m1 + (u - 1) * 2)[kSymCol];  
        } else {  
          // alpha(t, u) = log\_sum\_exp(alpha(t-1, u) + log\_probs(t-1, u).blank,  
          //                      alpha(t-1, u-1) + log\_probs(t-1, u-1).symbol)  
          float skip_prob =  
              p_alpha_t_m1[u] + (p_log_probs_t_m1 + u * 2)[kBlankCol];  
          float emit_prob =  
              p_alpha_t_m1[u - 1] + (p_log_probs_t_m1 + (u - 1) * 2)[kSymCol];  
          p\_alpha\_t[u] = LogAdd(skip_prob, emit_prob);  
          // p\_alpha\_t[u] = 0;  
        }  
      }  
    }  
    __syncthreads();  
  }  
  
  if (u == 0) {  
    total_scores[b] = *(p_alpha + T * U_p1 - 1) +  
                      (p_log_probs + (T * U_p1 - 1) * 2)[kBlankCol];  
  }  
}  

反向计算

  
// Call it like <<<batch\_size, max\_U\_p1>>>  
\_\_global\_\_ void ComputeBetaOneSymPerFrame(const float *log\_probs,  
                                          const int32\_t *logit\_lengths,  
                                          const int32\_t *target\_lengths,  
                                          const int32\_t *row\_splits,  
                                          int32\_t max\_T, int32\_t max\_U\_p1,  
                                          int32\_t *counter, float *beta) {  
  int32\_t b = blockIdx.x;  
  int32\_t u = threadIdx.x;  
  int32\_t T = logit_lengths[b];  
  int32\_t U_p1 = target_lengths[b] + 1;  
  
  int32\_t offset = row_splits[b];  
  float *p_beta = beta + offset;  
  const float *p_log_probs = log_probs + offset * 2;  
  
  if (u == 0) {  
    (p_beta + T * U_p1)[-1] = (p_log_probs + T * U_p1 * 2 - 2)[kBlankCol];  
  }  
  
  __syncthreads();  
  
  int32\_t diff = T - 1 - (U_p1 - 1);  
  
  for (int32\_t n = T + U_p1 - 2; n >= 0; --n) {  
    int32\_t t = n - u;  
    float *p\_beta\_t = p_beta + t * U_p1;  
    float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;  
    const float *p\_log\_probs\_t = p_log_probs + t * U_p1 * 2;  
  
    if (u == U_p1 - 1) {  
      // beta(t, U\_p1-1) = beta(t+1, U\_p1-1) + log\_probs(t, U\_p1-1).blank  
      if (u <= t && t < T - 1) {  
        p\_beta\_t[U_p1 - 1] =  
            p_beta_t_p1[U_p1 - 1] + (p\_log\_probs\_t + (U_p1 - 1) * 2)[kBlankCol];  
      }  
    } else if (u < U_p1) {  
      if (u <= t && t - u <= diff) {  
        if (t - u == diff) {  
          // beta(t, u) = beta(t+1, u+1) + log\_probs(t, u).symbol  
          p\_beta\_t[u] = p_beta_t_p1[u + 1] + (p\_log\_probs\_t + u * 2)[kSymCol];  
        } else {  
          // beta(t, u) = log\_sum\_exp(beta(t+1, u) + log\_probs(t, u).blank,  
          //                          beta(t+1, u+1) + log\_probs(t, u).symbol)  
          float skip_prob = p_beta_t_p1[u] + (p\_log\_probs\_t + u * 2)[kBlankCol];  
          float emit_prob =  
              p_beta_t_p1[u + 1] + (p\_log\_probs\_t + u * 2)[kSymCol];  
          p\_beta\_t[u] = LogAdd(skip_prob, emit_prob);  
        }  
      }  
    }  
    __syncthreads();  
  }  
}  

梯度计算

  
\_\_global\_\_ void ComputeGradientOneSymPerFrame(  
    const float *logits, const float *denominator, const int32\_t *targets,  
    const int32\_t *logit\_lengths, const int32\_t *target\_lengths, int32\_t blank,  
    const int32\_t *row\_splits, const int32\_t *row\_ids, int32\_t sum\_all\_TU,  
    int32\_t vocab\_size, int32\_t targets\_col, const float *alpha,  
    const float *beta, float *gradient) {  
  int32\_t idx01 = blockDim.x * blockIdx.x + threadIdx.x;  
  if (idx01 >= sum_all_TU) return;  // out-of-boundary  
  
  int32\_t b = row_ids[idx01];  // batch index  
  
  // +1 since it is prepended with a blank  
  int32\_t U_p1 = target_lengths[b] + 1;  
  int32\_t T = logit_lengths[b];  
  int32\_t offset = row_splits[b];  
  
  // 计算当前batch中的 t,u index  
  int32\_t idx1 = idx01 - offset;  
  int32\_t t = idx1 / U_p1;  
  int32\_t u = idx1 % U_p1;  
  
  const float *p_logits_t_u = logits + idx01 * vocab_size;  
  const float *p_denominator = denominator + offset;  
  const float *p\_denominator\_t = p_denominator + t * U_p1;  
  const int32\_t *p_targets = targets + b * targets_col;  
  
  const float *p_alpha = alpha + offset;  
  const float *p\_alpha\_t = p_alpha + t * U_p1;  
  
  const float *p_beta = beta + offset;  
  const float *p\_beta\_t = p_beta + t * U_p1;  
  const float *p_beta_t_p1 = p_beta + (t + 1) * U_p1;  
  
  float *p_grad_t_u = gradient + idx01 * vocab_size;  
  // nll  
  float loss = -1 * p_beta[0];  
  int32\_t diff = T - 1 - (U_p1 - 1);  
  
  if (u > t || t - u > diff || isinf(loss) || isnan(loss)) {  
    for (int32\_t v = 0; v != vocab_size; ++v) {  
      p_grad_t_u[v] = 0;  
    }  
    return;  
  }  
  
  float c = p\_alpha\_t[u] + loss - p\_denominator\_t[u];  
  
  int32\_t target_u = (u < U_p1 - 1) ? p_targets[u] : -1;  // -1 is not used  
  for (int32\_t v = 0; v != vocab_size; ++v) {  
    float g = p_logits_t_u[v] + c;  
    float val = 0;  
    if (v == blank && t == T - 1 && u == U_p1 - 1) {  
      // last blank transition  
      val = expf(g + p\_beta\_t[u]) - expf(g);  
    } else if (v == blank && t - u < diff) {  
      val = expf(g + p\_beta\_t[u]) - expf(g + p_beta_t_p1[u]);  
    } else if (u < U_p1 - 1 && v == target_u) {  
      val = expf(g + p\_beta\_t[u]) - expf(g + p_beta_t_p1[u + 1]);  
    } else {  
      val = expf(g + p\_beta\_t[u]);  
    }  
  
    p_grad_t_u[v] = val;  
  }  // end v  
}  

参考资料


欢迎关注公众号

0
0
0
0
评论
未登录
看完啦,登录分享一下感受吧~
暂无评论