CUDA Implementation and Optimization of self-attention (Part 1)

Overview
self-attention is the most critical and most complex part of the Transformer, and it is the core target of Transformer optimization. Understanding self-attention is therefore key to understanding the Transformer in depth. This post focuses on self-attention; because the topic is fairly involved it is split into two parts, and this is part one.
01

A Simple CUDA Implementation of self-attention
The principle of self-attention is well known and has been analyzed at length in earlier posts, so it is not repeated here; this post only walks through the code.
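For reference, every version below computes the same causal scaled dot-product attention. With per-head size $hs = C/NH$ and a causal mask $M$ that is $0$ on and below the diagonal and $-\infty$ above it:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{hs}} + M\right)V$$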
1. CPU version

The basic CPU implementation is given below; first, a brief walkthrough:
• The input inp holds the QKV values obtained by multiplying x with the QKV weights. For b (batch size), t (sequence length) and h (head), the q (query_t) values are indexed as inp[b, t, h*hs:(h+1)*hs]; the k (key_t2) values are the same index offset by C, i.e. inp[b, t, h*hs+C:(h+1)*hs+C].
• Given q and k, the attention score is computed as a dot product; each score is scaled, and the running maximum is tracked for the softmax. Once a row is complete, the causal mask is applied.
• A softmax pass then turns the scores into attention weights (attn).
• Finally v (value_t2) is indexed and multiplied with the attention weights.
// CPU code reference
void attention_forward_cpu(float* out, float* preatt, float* att,
                           const float* inp,
                           int B, int T, int C, int NH) {
    // input is (B, T, 3C) Q,K,V
    // preatt, att are (B, NH, T, T)
    // output is (B, T, C)
    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;

                // pass 1: calculate query dot key and maxval
                float maxval = -10000.0f; // TODO something better
                for (int t2 = 0; t2 <= t; t2++) {
                    const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

                    // (query_t) dot (key_t2)
                    float val = 0.0f;
                    for (int i = 0; i < hs; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;
                    if (val > maxval) {
                        maxval = val;
                    }

                    preatt_bth[t2] = val;
                }
                // pad with -INFINITY outside of autoregressive region for debugging comparisons
                for (int t2 = t+1; t2 < T; t2++) {
                    preatt_bth[t2] = -INFINITY;
                }

                // pass 2: calculate the exp and keep track of sum
                float expsum = 0.0f;
                for (int t2 = 0; t2 <= t; t2++) {
                    float expv = expf(preatt_bth[t2] - maxval);
                    expsum += expv;
                    att_bth[t2] = expv;
                }
                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

                // pass 3: normalize to get the softmax
                for (int t2 = 0; t2 < T; t2++) {
                    if (t2 <= t) {
                        att_bth[t2] *= expsum_inv;
                    } else {
                        // causal attention mask. not strictly necessary to set to zero here
                        // only doing this explicitly for debugging and checking to PyTorch
                        att_bth[t2] = 0.0f;
                    }
                }

                // pass 4: accumulate weighted values into the output of attention
                float* out_bth = out + b * T * C + t * C + h * hs;
                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
                for (int t2 = 0; t2 <= t; t2++) {
                    const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
                    float att_btht2 = att_bth[t2];
                    for (int i = 0; i < hs; i++) {
                        out_bth[i] += att_btht2 * value_t2[i];
                    }
                }
            }
        }
    }
}

2. An initial CUDA implementation (V1)
This version keeps the basic structure of the CPU code and simply splits the computation into 3 kernels (a host-side sketch of how they are chained is given below):
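A rough host-side sketch of the chaining, assuming the usual ceil_div rounding helper (the step-1 launch is shown again in the walkthrough below):

void attention_forward1(float* out, float* preatt, float* att,
                        const float* inp,
                        int B, int T, int C, int NH,
                        const int block_size) {
    // kernel 1: one thread per (b, h, t, t2) attention score
    int total_threads = B * NH * T * T;
    int num_blocks = ceil_div(total_threads, block_size);
    attention_query_key_kernel1<<<num_blocks, block_size>>>(preatt, inp, B, T, C, NH);
    // kernels 2 and 3: one thread per (b, t, h) row
    total_threads = B * T * NH;
    num_blocks = ceil_div(total_threads, block_size);
    attention_softmax_kernel1<<<num_blocks, block_size>>>(att, preatt, B, T, NH);
    attention_value_kernel1<<<num_blocks, block_size>>>(out, att, inp, B, T, C, NH);
}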
• Step 1: compute the attention scores, using a total of B*NH*T*T threads, i.e. one thread per score:

// attention calculation
int total_threads = B * NH * T * T;
int num_blocks = ceil_div(total_threads, block_size);
attention_query_key_kernel1<<<num_blocks, block_size>>>(preatt, inp, B, T, C, NH);

The kernel is implemented as follows:
__global__ void attention_query_key_kernel1(float* preatt, const float* inp,
                                            int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * NH * T * T;

    if (idx < total_threads) {
        int t2 = idx % T;
        int t = (idx / T) % T;
        if (t2 > t) {
            // autoregressive mask
            preatt[idx] = -INFINITY;
            return;
        }
        int h = (idx / (T * T)) % NH;
        int b = idx / (NH * T * T);

        int C3 = C*3;
        int hs = C / NH; // head size
        const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
        const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

        // (query_t) dot (key_t2)
        float val = 0.0f;
        for (int i = 0; i < hs; i++) {
            val += query_t[i] * key_t2[i];
        }
        val *= 1.0 / sqrtf(hs);
        preatt[idx] = val;
    }
}

• Step 2: the softmax. This operation has already been discussed in detail in an earlier post on op optimization, so it is not elaborated here:
__global__ void attention_softmax_kernel1(float* att, const float* preatt,
                                          int B, int T, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * T * NH;

    if (idx < total_threads) {
        int h = idx % NH;
        int t = (idx / NH) % T;
        int b = idx / (NH * T);

        const float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
        float* att_bth = att + b*NH*T*T + h*T*T + t*T;

        // find maxval
        float maxval = -10000.0f; // TODO something better
        for (int t2 = 0; t2 <= t; t2++) {
            if (preatt_bth[t2] > maxval) {
                maxval = preatt_bth[t2];
            }
        }

        // calculate the exp and keep track of sum
        float expsum = 0.0f;
        for (int t2 = 0; t2 <= t; t2++) {
            float expv = expf(preatt_bth[t2] - maxval);
            expsum += expv;
            att_bth[t2] = expv;
        }
        float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

        // normalize to get the softmax
        for (int t2 = 0; t2 < T; t2++) {
            if (t2 <= t) {
                att_bth[t2] *= expsum_inv;
            } else {
                // causal attention mask. not strictly necessary to set to zero here
                // only doing this explicitly for debugging and checking to PyTorch
                att_bth[t2] = 0.0f;
            }
        }
    }
}

• Step 3: multiply the attention weights with v:
__global__ void attention_value_kernel1(float* out, const float* att, const float* inp,
                                        int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * T * NH;

    if (idx < total_threads) {
        int h = idx % NH;
        int t = (idx / NH) % T;
        int b = idx / (NH * T);

        int C3 = C*3;
        int hs = C / NH; // head size

        float* out_bth = out + b * T * C + t * C + h * hs;
        const float* att_bth = att + b*NH*T*T + h*T*T + t*T;

        for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
        for (int t2 = 0; t2 <= t; t2++) {
            const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
            float att_btht2 = att_bth[t2];
            for (int i = 0; i < hs; i++) {
                out_bth[i] += att_btht2 * value_t2[i];
            }
        }
    }
}

This completes the most basic implementation of self-attention. The performance numbers are:
block_size 32 | time 238.912872 ms
block_size 64 | time 252.689301 ms
block_size 128 | time 246.945175 ms
block_size 256 | time 261.469421 ms
block_size 512 | time 241.190613 ms

3. A simple flash attention implementation (V2)
flash attention is an extremely important optimization of self-attention built around the GPU memory hierarchy: Q, K and V are processed in tiles that fit in on-chip shared memory (SRAM), so the full T×T attention matrix never has to be materialized in HBM.
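The key observation that makes this tiling possible is that softmax can be computed incrementally. For each output row the kernel keeps a running maximum $m$ and a running denominator $l$; when a new key/value tile with per-tile statistics $\tilde{m}$, $\tilde{l}$ and un-normalized product $\tilde{P} V_j$ arrives, the state is folded in as

$$m_{\mathrm{new}} = \max(m_{\mathrm{prev}}, \tilde{m}), \qquad l_{\mathrm{new}} = e^{\,m_{\mathrm{prev}} - m_{\mathrm{new}}}\, l_{\mathrm{prev}} + e^{\,\tilde{m} - m_{\mathrm{new}}}\, \tilde{l},$$

$$O_{\mathrm{new}} = \frac{1}{l_{\mathrm{new}}}\Big( l_{\mathrm{prev}}\, e^{\,m_{\mathrm{prev}} - m_{\mathrm{new}}}\, O_{\mathrm{prev}} + e^{\,\tilde{m} - m_{\mathrm{new}}}\, \tilde{P} V_j \Big),$$

which is exactly what the row_m_new / row_l_new / O update lines in the kernel below implement.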
• First, the key parameters are initialized:

// these are hardcoded to 32 for now
const int Bc = 32;
const int Br = 32;
// renaming these to be consistent with the kernel
// const int B = B;
const int nh = NH;
const int N = T;
const int d = C / NH;
// more
const int Tc = ceil((float) N / Bc);
const int Tr = ceil((float) N / Br);
const float softmax_scale = 1.0 / sqrt(d);
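As a concrete example (the shape here is an assumption, roughly the GPT-2-small benchmark configuration, not stated in the code above): with T = 1024, C = 768, NH = 12 we get d = 64, so Tc = Tr = 1024 / 32 = 32 tiles, and the SRAM check below requests 2·(32·64) + (32·64) + 32·32 = 7168 floats ≈ 28 KB per block, comfortably under the usual 48 KB default shared-memory limit.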
• Then the SRAM needed per block is computed, to make sure it does not exceed the hardware limit:

// calculate SRAM size needed per block, ensure we have enough shared memory
int col_tile_size = Bc * d;  // size of Kj, Vj
int row_tile_size = Br * d;  // size of Qi
const int sram_size =
    (2 * col_tile_size * sizeof(float))  // SRAM size for Kj, Vj
    + (row_tile_size * sizeof(float))    // SRAM size for Qi
    + (Bc * Br * sizeof(float));         // SRAM size for S
int max_sram_size;
cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
if (sram_size > max_sram_size) {
    printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);
    printf("SRAM size exceeds maximum shared memory per block\n");
    printf("Try decreasing col_tile_size or row_tile_size further\n");
    exit(1);
}

• To avoid complicated indexing, reshape and permute operations inside the flash attention kernel, a separate kernel handles them first:
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
                               int B, int N, int NH, int d) {
    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)
    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;
        q[idx] = inp[inp_idx];
        k[idx] = inp[inp_idx + NH * d];
        v[idx] = inp[inp_idx + 2 * (NH * d)];
    }
}
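A sketch of how this permute kernel might be launched, assuming q, k and v device buffers of B*NH*N*d floats each have already been allocated (the exact launch code is not shown in this excerpt):

int total_threads = B * NH * N * d;   // one thread per output element
int num_blocks = ceil_div(total_threads, block_size);
permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, N, NH, d);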
• Next comes the core part, the flash attention kernel itself; its structure follows the tiling scheme and the update equations given above:

__global__ void attention_forward_kernel2(
    const float* Q,
    const float* K,
    const float* V,
    const int N,
    const int d,
    const int Tc,
    const int Tr,
    const int Bc,
    const int Br,
    const float softmax_scale,
    float* l,
    float* m,
    float* O
) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N);           // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d;  // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {

        // Load Kj, Vj to SRAM
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
        }
        __syncthreads();  // such that the inner loop can use the correct Kj, Vj

        for (int i = 0; i < Tr; i++) {
            // if past the end of the sequence, break
            if (i * Br + tx >= N) {
                break;
            }

            // Load Qi to SRAM, l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            // S[tx][y] = Sum_{x = 0}^{d-1} {Qi[tx][x] * Kj[y][x]}
            // row_m = Max_{y = 0}^{Bc-1} S[tx][y]
            // with causal masking
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N) {
                    break;
                }
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                if (i * Br + tx < j * Bc + y)
                    sum = -INFINITY;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // implement softmax with causal masking
            // P = exp(S - row_m), row_l = rowsum(P)
            // P[tx][y] = exp(S[tx][y] - row_m)
            float row_l = 0;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N) {
                    break;
                }
                if (i * Br + tx < j * Bc + y)
                    S[(Bc * tx) + y] = 0;
                else
                    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Compute new m and l
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++) {
                float pv = 0;  // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    if (j * Bc + y >= N) {
                        break;
                    }
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1 / row_l_new) \
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
                    + (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads();  // otherwise, thread can use the wrong Kj, Vj in inner loop
    }
}
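A sketch of the corresponding launch: one block per (batch, head) pair, Bc threads per block, and sram_size bytes of dynamic shared memory. It assumes an output buffer o with the same (B, NH, N, d) layout as q/k/v, l zero-initialized and m initialized to a very negative value per row, as the update equations require; the exact host code is not shown in this excerpt:

dim3 grid_dim(B, nh);   // one block per (batch, head)
dim3 block_dim(Bc);     // each thread owns one row of the current Q tile
attention_forward_kernel2<<<grid_dim, block_dim, sram_size>>>(
    q, k, v, N, d, Tc, Tr, Bc, Br, softmax_scale, l, m, o);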
• After this computation, an unpermute step is still required:
__global__ void unpermute_kernel(const float* inp, float *out, int B, int N, int NH, int d) {
    // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
        out[other_idx] = inp[idx];
    }
}

This completes a simple forward pass of flash attention 1. Its performance is actually worse than V1, mainly because the problem size here is small. The numbers are:
block_size 32 | time 536.709961 ms
block_size 64 | time 526.100098 ms
block_size 128 | time 583.016235 ms
block_size 256 | time 573.955994 ms
block_size 512 | time 534.477051 ms

02
Efficient Implementations of self-attention
1. Using cuBLAS library functions (V3)
In the previous implementations every operation was written by hand. The results are perfectly correct, but performance is still far from what the official libraries achieve. In this section the matrix multiplications inside self-attention are therefore handed over to the official cuBLAS library.
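Per (batch, head) slice, with Q, K and V of shape (T, HS) after the permute, the two strided-batched GEMMs compute

$$\mathrm{preatt}_{b,h} = Q_{b,h} K_{b,h}^{\top} \in \mathbb{R}^{T \times T}, \qquad \mathrm{vaccum}_{b,h} = \mathrm{att}_{b,h} V_{b,h} \in \mathbb{R}^{T \times HS},$$

with the scaled, masked softmax applied between the two calls to turn preatt into att; the CUBLAS_OP_T / CUBLAS_OP_N flags and the leading dimensions account for cuBLAS's column-major storage convention.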
Only the two matrix multiplications are shown here. First q @ k^T:
// batched matrix multiply with cuBLAS
const float alpha = 1.0f;
const float beta = 0.0f;
cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                      CUBLAS_OP_T, CUBLAS_OP_N,
                                      T, T, HS,
                                      &alpha,
                                      k, HS, T * HS,
                                      q, HS, T * HS,
                                      &beta,
                                      preatt, T, T * T,
                                      B * NH));

Then att @ v:

// new approach: first cuBLAS another batched matmul
// y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
cublasCheck(cublasSgemmStridedBatched(cublas_handle,
                                      CUBLAS_OP_N, CUBLAS_OP_N,
                                      HS, T, T,
                                      &alpha,
                                      v, HS, T * HS,
                                      att, T, T * T,
                                      &beta,
                                      vaccum, HS, T * HS,
                                      B * NH));

Performance improves by roughly two orders of magnitude over V1:
block_size 32 | time 4.318913 ms
block_size 64 | time 2.606850 ms
block_size 128 | time 2.034935 ms
block_size 256 | time 2.031407 ms
block_size 512 | time 2.064406 ms

2. Operator fusion and online softmax (V4)
On top of V3, an online softmax is used, and the scale operation is fused into the softmax kernel. The kernel is given below, after a small float4 helper that it relies on.
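The kernel reads each row through float4 vector loads and indexes the components with a helper vec_at, which is not shown in this excerpt; a minimal definition consistent with that usage (an assumption, not necessarily the original helper) is:

// return the index-th float packed inside a float4
__device__ inline float vec_at(const float4& vec, int index) {
    return reinterpret_cast<const float*>(&vec)[index];
}

With that in place, the fused online-softmax kernel is: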
__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
    // inp, out shape: (N, T, T), where N = B * NH
    // fuses the multiplication by scale inside attention
    // directly autoregressive, so we only compute the lower triangular part
    // uses the online softmax algorithm
    assert(T % 4 == 0);
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if (idx >= N * T) {
        return;
    }
    int own_pos = idx % T;
    int pos_by_4 = own_pos / 4;

    // one row of inp, i.e. inp[idx, :] of shape (T,)
    const float* x = inp + idx * T;

    // not INF, so we don't get NaNs accidentally when subtracting two values.
    float maxval = -FLT_MAX;
    float sumval = 0.0f;

    const float4* x_vec = reinterpret_cast<const float4*>(x);
    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
        float4 v = x_vec[i];
        float old_maxval = maxval;
        for (int k = 0; k < 4; ++k) {
            maxval = fmaxf(maxval, vec_at(v, k));
        }
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        for (int k = 0; k < 4; ++k) {
            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));
        }
    }

    if (4*pos_by_4 + warp.thread_rank() <= own_pos) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));
    }

    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});
    sumval *= expf(inv_temperature * (maxval - global_maxval));

    float sum = cg::reduce(warp, sumval, cg::plus<float>{});
    float norm = 1.f / sum;

    // divide the whole row by the sum
    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
        // recalculation is faster than doing the round-trip through memory.
        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));
        __stcs(out + idx * T + i, ev * norm);
    }
}

The remaining operations are unchanged. Performance improves slightly:
block_size 32 | time 1.198167 ms
block_size 64 | time 1.073088 ms
block_size 128 | time 1.042434 ms
block_size 256 | time 1.041798 ms
block_size 512 | time 1.044009 ms

3. Matrix operations in FP16 (V5)
In V5, the FP32-to-FP16 (and back) conversion is performed in the permute/unpermute stage.
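permute_kernel_lowp and unpermute_kernel_lowp are not shown in this excerpt. A minimal sketch of the low-precision permute, assuming __half outputs and the same indexing as the FP32 permute_kernel above (the real kernels may differ, e.g. use BF16 or another rounding mode):

#include <cuda_fp16.h>

__global__ void permute_kernel_lowp(__half* q, __half* k, __half* v,
                                    const float* inp,
                                    int B, int N, int NH, int d) {
    // same (B, N, 3, NH, d) -> 3 x (B, NH, N, d) shuffle as permute_kernel,
    // but converting the FP32 inputs to FP16 on the way out
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;
        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (nh_ * d) + d_;
        q[idx] = __float2half(inp[inp_idx]);
        k[idx] = __float2half(inp[inp_idx + NH * d]);
        v[idx] = __float2half(inp[inp_idx + 2 * (NH * d)]);
    }
}

The forward pass then calls the low-precision permute/unpermute as follows: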
if (!skip_permute || first_run_validation) {
    permute_kernel_lowp<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);
}
...
if (!skip_permute || first_run_validation) {
    unpermute_kernel_lowp<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
}

Performance numbers:
block_size 32 | time 0.866851 ms
block_size 64 | time 0.743674 ms
block_size 128 | time 0.703196 ms
block_size 256 | time 0.713902 ms
block_size 512 | time 0.712848 ms

The chart below compares the versions above (note that the time axis is logarithmic); all measurements were taken on an A100-80G.