CUDA BlockReduce and WarpReduce

warp-level primitives

https://developer.nvidia.com/zh-cn/blog/using-cuda-warp-level-primitives

https://zhuanlan.zhihu.com/p/572820783

  
#define FULL\_WARP\_MASK 0xFFFFFFFF  
  
#define CREATE\_SHFL\_MASK(mask, predicate) \  
  mask = \_\_ballot\_sync(FULL\_WARP\_MASK, (predicate))  
  
  
template <typename T>  
\_\_forceinline\_\_ \_\_device\_\_ T CudaShuffleDownSync(unsigned mask, T val, int delta, int width = warpSize) {  
  return __shfl_down_sync(mask, val, static\_cast<unsigned>(delta), width);  
}  

__ballot_sync

  
int __ballot_sync(unsigned mask, int predicate);  

返回一个 32 位无符号整数,代表了该线程束内变量 predicate 的非零值分布情况(即线程 predicate 为零的该函数返回值该位为 0,线程 predicate 非零的该函数返回值该位为 1 )

  
unsigned mask = __ballot_sync(FULL_MASK, threadIdx.x < NUM_ELEMENTS);  
if (threadIdx.x < NUM_ELEMENTS) {   
    val = input[threadIdx.x];   
    for (int offset = 16; offset > 0; offset /= 2)  
        val += __shfl_down_sync(mask, val, offset);  
    …  
}  

假设我们要计算数组 input[] 的所有元素的总和,其大小 NUM_ELEMENTS 小于线程块中的线程数,这个时候就可以考虑使用 __ballot_sync() 指定 predicate 为 thread.Idx.x < NUM_ELEMENTS 来计算 __shfl_down_sync() 需要的成员掩码 mask, 从而决定哪些线程将参与规约求和任务。__ballot_sync() 自身使用 FULL_MASK(32 个线程为 0xffffffff),因为我们假设所有线程都会执行它。

__shfl_down_sync

  
int __shfl_down_sync(unsigned mask, int var, unsigned detla, int width=warpSize);  

表示被 mask 指定的线程返回向后偏移为 delta 的线程中的变量 var 的值,其余线程返回0;

picture.image

__shfl_down_sync(0xffffffff, x, 2),表示 laneid 为 0 ~29 的线程分别获得 laneid 为 2 ~ 31 的线程中变量 x 的值;

picture.image

__shfl_down_sync(0xffffffff, x, 2, 16),表示 laneid 为 0 ~13 的线程分别获得 laneid 为 2 ~ 15 的线程中变量 x 的值;laneid 为 16 ~29 的线程分别获得 laneid 为 18 ~ 31 的线程中变量 x 的值;

picture.image

__shfl_down_sync primitives 其实随处可见,在广泛应用的并行规约(parallel reductions)算法中,最常见的就是并行规约求和(BlockReduceSum)。理解 BlockReduceSum 之前先来看一下 WarpReduceSum 的实现过程,这里截取一段 pytorch 中实现的 WarpReduceSum 代码,共同学习一下;

经过 log2(warpSize)=5 次 shfl_down.

  
// Sums `val` accross all threads in a warp.  
//  
// Assumptions:  
//   - The size of each block should be a multiple of `warpSize`  
template <typename T>  
\_\_inline\_\_ \_\_device\_\_ T WarpReduceSum(T val) {  
#pragma unroll  
  for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {  
    val += __shfl_down_sync(0xffffffff, val, offset, warpSize);  
  }  
  return val;  
}  

这段代码的实现逻辑可以借助下面这张图来辅助理解,假设初始 warp 内每个线程的 val 值为1,经过 5 轮循环之后线程0获得最终正确的 reduce sum 结果val=32。注意,这里只画了有助于理解最终规约结果的线程,实则所有线程都会改变值,只是他们在规约中并不会用到而已;

picture.image

有了 WarpReduceSum 的基础,那么再截取一段 pytorch 中实现的 BlockReduceSum 代码,BlockReduce 主要借助 WarpReduce 来做,因此 blocksize 必须是 warp 的整数倍。整个流程如下:

  1. 首先让所有线程执行 WarpReduceSum
  2. 然后将每个线程束的 reduce 结果存储到 shared memory 中,注意这里是 lane_id=0 的线程去存储,因为前面提到了只有线程0上有正确的reduce结果
  3. 从 shared memory 把数据读取出来,最后再用一个 warp 对其做 reduce,即可获得整个 block 的 reduce 结果
  
// Sums `val` accross all threads in a block.  
//  
// Assumptions:  
//   - Thread blocks are an 1D set of threads (indexed with `threadIdx.x` only)  
//   - The size of each block should be a multiple of `warpSize`  
//   - `shared` should be a pointer to shared memory with size of, at least,  
//     `sizeof(T) * number\_of\_warps`  
template <typename T>  
\_\_inline\_\_ \_\_device\_\_ T BlockReduceSum(T val, T* shared) {  
  const int laneid = threadIdx.x % warpSize;  
  const int warpid = threadIdx.x / warpSize;  
  val = WarpReduceSum(val);  
  __syncthreads();  
  if (laneid == 0) {  
    shared[warpid] = val;  
  }  
  __syncthreads();  
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[laneid] : T(0);  
  if (warpid == 0) {  
    val = WarpReduceSum(val);  
  }  
  return val;  
}  

WarpReduce

  
constexpr int kWarpSize = 32;  
  
// ReduceOp 可以是任何 Binary Op, e.g sum, or, and  
template <typename T, typename ReduceOp>  
\_\_device\_\_ \_\_forceinline\_\_ T WarpReduce(T val, ReduceOp reducer) {  
  unsigned mask = 0u;  
  mask = __ballot_sync(0xFFFFFFFF, true);  
    
  for (int stride = kWarpSize / 2; stride > 0; stride >>= 1) {  
    T temp = __shfl_down_sync(mask, val, stride);  
    val = reducer(val, temp);  
  }  
  return val;  
}  

BlockReduce

warpSize = 32, 能在一个warpSize内规约block中的数,那么 blockDim.x <= warpSize * warpSize = 1024, 且 blockDim.x % warpSize = 0;

blockDim.x = 128 的话,可以规约 4 个Warp的数。

  
/* e.g.  
 * |---------block---------|  
 * |warp0|warp1|warp2|warp3|  
 * |0~31|32~63|64~95|96~127|  ---->blockDim.x = 128  
 *  \|/  \|/   \|/    \|/     ---->1. First WarpReduce in each warp  
 * res0  res1  res2  res3     ---->2. Store result of each warp to shared memory  
 *   \    \    /     /        ---->3. Load the result above from shared memory  
 *        res                         to warp0 and process the second WarpReduce  
 */  

BlockXReduce

BlockXReduce 在 blockDim.x 的方向上规约,即按行规约。

  
/**  
 * @brief BlockXReduce reduce along blockDim.x.  
 * blockDim.y不变,返回 threadIdx.y 行规约后的结果。  
 */  
template <typename T, typename ReduceOp>  
\_\_device\_\_ \_\_forceinline\_\_ T BlockXReduce(T val, ReduceOp reducer) {  
  __syncthreads();  
  using kWarpSize;  
  // 最多可以处理 2*warpSize 个 warp。  
  __shared__ T shared[2 * kWarpSize];  
    
  int block_dim_x = blockDim.x;  
    
  if (blockDim.x > kWarpSize) {  
    // blockDim.x 大于 warpSize 时。即,不能在一个warpReduce内完成所有数的规约。  
      
    // Bit operation can be used when kWarpSize is 32 or 64 now  
    // kWarpSize==32, rshift\_val=5; 2^5 = 32  
    // kWarpSize==64, rshift\_val=6; 2^6 = 64  
    // other, rshift\_val = 5; 2^5 = 32  
    constexpr int rshift_val =  
        (kWarpSize != 32) ? ((kWarpSize == 64) ? 6 : 5) : 5;  
      
    // x 方向上包含多少个warp,  
    // 即 block\_dim\_x = blockDim.x / warpSize  
    block_dim_x = blockDim.x >> rshift_val;  
      
    // lane = threadIdx.x % kWarpSize  
    int lane = threadIdx.x & (kWarpSize - 1);  
      
    int tid = threadIdx.y * blockDim.x + threadIdx.x;  
    // warp id  
    int wid = tid >> rshift_val;  
    // row id = block id  
    int bid = threadIdx.y;  
      
    val = WarpReduce(val, reducer);  
    if (lane == 0) {  
      // warp id  
      shared[wid] = val;  
    }   
    __syncthreads();  
      
    // 为 <2> WarpReduce 加载数据。y方向shape不变,x方向缩减为warp的个数  
    val = shared[bid * block_dim_x + lane];  
  }  
  
  // <2> WarpReduce, 在一个warp内可以完成规约。  
  unsigned mask = 0u;  
  CREATE_SHFL_MASK(mask, true);  
  for (int stride = 1; stride < block_dim_x; stride <<= 1) {  
    T temp = CudaShuffleDownSync(mask, val, stride);  
    val = reducer(val, temp);  
  }  
  __syncthreads();  
    
  // x方向结果规约到 threadIdx.x,并根据行号threadIdx.y保存数据。  
  // 即,矩阵按行规约。  
  if (threadIdx.x == 0) {  
    shared[threadIdx.y] = val;  
  }  
  __syncthreads();  
    
  // 返回对应行的规约结果。  
  return shared[threadIdx.y];  
}  

BlockYReduce

BlockYReduce 在 blockDim.y 的方向上规约,即按列规约。

  
/**  
 * @brief Will be used in BlockYReduce, get the index of reduce\_num in shared  memory.  
 */  
\_\_device\_\_ \_\_forceinline\_\_ int SharedMemoryIndex(int stride) {  
  return (threadIdx.y + stride) * blockDim.x + threadIdx.x;  
}  
  
/**  
 * @brief BlockYReduce reduce along blockDim.y.  
 */  
template <typename T, typename ReduceOp>  
\_\_device\_\_ \_\_forceinline\_\_ T BlockYReduce(T val, ReduceOp reducer) {  
  // block 中的线程数最多1024 = 32*32  
  __shared__ T shared_memory[1024];  
    
  // block 中的数加载到 shared memory中。  
  shared_memory[SharedMemoryIndex(0)] = val;  
    
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {  
    __syncthreads();  
      
    if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y)   
      // reduce threadIdx.y and threadIdx.y + stride  
      T temp = shared_memory[SharedMemoryIndex(stride)];  
      val = reducer(val, temp);  
    }  
      
    shared_memory[SharedMemoryIndex(0)] = val;  
  }  
  __syncthreads();  
    
    
  // 获取按列规约后的结果,即row0的结果。  
  return shared_memory[threadIdx.x];  
}  

参考文献

0
0
0
0
评论
未登录
暂无评论