Project repository: https://github.com/jinmingyi1998/make-torch-inplace
Intro
As everyone knows, attention takes a lot of GPU memory, and the matrix multiplications inside it are a major part of that cost.
A general matrix multiplication is $A_{m \times k} \times B_{k \times n} = C_{m \times n}$.
In self-attention, however, the weight matrices in many of these multiplications are square, i.e. $A_{m \times n} \times B_{n \times n} = C_{m \times n}$. Here $C$ has the same shape as $A$, so if I don't need $A$ anymore, I can write the result $C$ straight back into $A$'s storage during the computation.
But Torch's matmul API doesn't care about this case. It is general-purpose: $C$ cannot be placed into $A$, and even if it could, the implementation would still spend extra memory to buy speed. After all, in a tiled GEMM every element of $A$ is read multiple times, so it cannot simply be overwritten in place.
So I'll just build an in-place square-matrix multiplication myself.
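Before the CUDA version, here is a minimal CPU sketch of the trick (my own illustration, not code from the repo; `square_matmul_inplace_cpu` is a hypothetical name): each output row is accumulated into a scratch buffer and only written back once the old row has been fully consumed.

```cpp
#include <vector>

// In-place C = A * B, where A is m x n (row-major) and B is n x n (square).
// Every element of the new row depends on the whole old row, so the new row
// is first accumulated into a scratch buffer and only then copied back.
void square_matmul_inplace_cpu(float *A, const float *B, int m, int n) {
    std::vector<float> row_buf(n);
    for (int r = 0; r < m; r++) {
        float *row = A + r * n;
        for (int c = 0; c < n; c++) {
            float acc = 0.0f;
            for (int j = 0; j < n; j++) {
                acc += row[j] * B[j * n + c];
            }
            row_buf[c] = acc;
        }
        // Safe to overwrite now: the old row is no longer needed.
        for (int c = 0; c < n; c++) row[c] = row_buf[c];
    }
}
```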
Torch CUDA Extension
CUDA C
CUDA Intro
CUDA code contains a device (GPU) side and a host (CPU) side, and functions support templates (nvcc is impressive).
Threads in a kernel are organized as Grid - Block - Thread. Every 32 adjacent threads form a warp (the warp size is usually 32), and all threads of a warp always belong to the same block. The warp is the smallest scheduling unit; a block whose size is not a multiple of the warp size is padded up to one.
Memory types include global, shared, and local; see the earlier post CUDA Shared Memory for more.
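As a quick illustration of this hierarchy (a generic sketch, not code from this project), the usual index arithmetic inside a kernel looks like this:

```cpp
__global__ void index_demo_kernel(int *out, int n) {
  // Global thread index for a 1-D launch.
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int warp_id = threadIdx.x / 32;  // which warp within this block
  int lane_id = threadIdx.x % 32;  // which lane within that warp
  if (tid < n) out[tid] = warp_id * 32 + lane_id;
}
```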
kernel function
A not-so-efficient implementation:
Each block handles ROW_PER_BLOCK rows of $C$, and within a block each warp (`threadIdx.x / 32`) takes one row.
Each row of $C$ is then divided among the 32 threads of the warp: every lane computes a contiguous segment of `cols_per_thread = ceil(cols / 32)` output columns, each as a full dot product between the input row and the corresponding column of the square matrix.
The per-lane results are buffered locally and only written back over the input row after a synchronization barrier, so the old row is never overwritten while it is still being read.
```cpp
// In-place C = mat * square_mat (or mat * square_mat^T), written back into mat.
// One warp computes one row of the output; ROW_PER_BLOCK = warps per block.
template <typename T, typename T_W>
__global__ void square_matmul_inplace_kernel_(T *mat, T_W *square_mat, int rows,
                                              int cols, bool transpose_square) {
  int lane_id = threadIdx.x % 32;                             // lane within the warp
  int C_row = blockIdx.x * ROW_PER_BLOCK + threadIdx.x / 32;  // row handled by this warp

  // Split the row's `cols` output columns into 32 contiguous segments.
  int cols_per_thread = (cols + 31) / 32;
  int cols_this_thread = cols_per_thread;
  int last_y = cols / cols_per_thread;
  if (lane_id == last_y) {
    cols_this_thread = cols - cols_per_thread * last_y;  // shorter tail segment
  } else if (lane_id > last_y) {
    cols_this_thread = 0;  // lanes past the end are idle
  }

  T *row_input = mat + C_row * cols;
  float thread_sum_buf[64];  // per-lane result buffer (caps cols at 32 * 64)

  if (C_row < rows) {
    for (int i = 0; i < cols_this_thread; i++) {
      int C_col = lane_id * cols_per_thread + i;
      float thread_sum = 0.0;
      // Full dot product: input row times a column (or row, if transposed)
      // of the square matrix, accumulated in fp32.
      for (int j = 0; j < cols; j++) {
        int mat2_idx;
        if (transpose_square) {
          mat2_idx = C_col * cols + j;
        } else {
          mat2_idx = j * cols + C_col;
        }
        float v1 = static_cast<float>(row_input[j]);
        float v2 = static_cast<float>(square_mat[mat2_idx]);
        thread_sum += v1 * v2;
      }
      thread_sum_buf[i] = thread_sum;
    }
  }

  // Make sure every thread in the block has finished reading its row
  // before any thread overwrites the input.
  __syncthreads();

  if (C_row < rows) {
    for (int i = 0; i < cols_this_thread; i++) {
      int C_col = lane_id * cols_per_thread + i;
      row_input[C_col] = static_cast<T>(thread_sum_buf[i]);
    }
  }
}
```
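The launch configuration isn't shown in the post. Assuming one warp per output row and ROW_PER_BLOCK warps per block (the macro's actual value is defined elsewhere in the repo), a plausible host-side launch would be:

```cpp
// Hypothetical launch helper: 32 threads (one warp) per row of C,
// ROW_PER_BLOCK rows covered by each block.
void launch_square_matmul_inplace(float *mat, float *square_mat, int rows, int cols) {
  dim3 block(ROW_PER_BLOCK * 32);
  dim3 grid((rows + ROW_PER_BLOCK - 1) / ROW_PER_BLOCK);  // cover all rows
  square_matmul_inplace_kernel_<float, float><<<grid, block>>>(
      mat, square_mat, rows, cols, /*transpose_square=*/false);
}
```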
| symbol | function |
| --- | --- |
| `__global__` | declares a kernel, which is called from the host and executed on the device |
| `__device__` | declares a device function, which is called and executed on the device |
| `__host__` | declares a host function, which is called and executed on the host |
| `__noinline__` | avoid inlining |
| `__forceinline__` | force inlining |
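As a small illustration of my own (not from the post), these qualifiers can be combined, e.g. for a helper that compiles for both host and device:

```cpp
// Usable from both CPU and GPU code; inlining forced on the device path.
__host__ __device__ __forceinline__ float square(float x) { return x * x; }

__global__ void square_kernel(float *data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = square(data[i]);  // device-side call
}
```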
C API
Just pybind11
```cpp
#include <torch/extension.h>

void square_matmul_inplace_(at::Tensor input, at::Tensor square_mat, int rows, int cols);
void square_matmul_inplace_T_(at::Tensor input, at::Tensor square_mat, int rows, int cols);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("square_matmul", &square_matmul_inplace_, "Matmul inplace (CUDA)");
  m.def("square_matmul_transposed", &square_matmul_inplace_T_, "Matmul inplace transposed (CUDA)");
}
```
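The bodies of these two wrappers live in the `.cu` sources and aren't reproduced in the post. As a hedged sketch of what such a wrapper might look like, assuming it just computes the launch shape from above and dispatches on dtype (`AT_DISPATCH_FLOATING_TYPES_AND_HALF` is ATen's standard dispatch macro; the repo's actual code may differ):

```cpp
// Hypothetical sketch of the wrapper body, not the repo's actual code.
void square_matmul_inplace_(at::Tensor input, at::Tensor square_mat, int rows, int cols) {
  dim3 block(ROW_PER_BLOCK * 32);
  dim3 grid((rows + ROW_PER_BLOCK - 1) / ROW_PER_BLOCK);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "square_matmul_inplace_", [&] {
        square_matmul_inplace_kernel_<scalar_t, scalar_t><<<grid, block>>>(
            input.data_ptr<scalar_t>(), square_mat.data_ptr<scalar_t>(),
            rows, cols, /*transpose_square=*/false);
      });
}
```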
Python Setup File
Use PyTorch cpp_extension
Example code (needs modification):
```python
import os
import subprocess

from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension

version_dependent_macros = [
    "-DVERSION_GE_1_1",
    "-DVERSION_GE_1_3",
    "-DVERSION_GE_1_5",
]

extra_cuda_flags = [
    "-std=c++14",
    "-maxrregcount=50",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
]

cc_number = [75, 80, 86, 89]
cc_flag = []
for n in cc_number:
    cc_flag.append("-gencode")
    cc_flag.append(f"arch=compute_{n},code=sm_{n}")

print(cc_flag)

extra_cuda_flags += cc_flag


def get_cuda_bare_metal_version(cuda_dir=CUDA_HOME):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


nvcc_raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version()

setup(
    packages=find_packages(),
    include_package_data=True,
    package_data={
        "make_torch_inplace": ["csrc/*"],
    },
    ext_modules=[
        CUDAExtension(
            name="make_torch_inplace_C",
            sources=[
                "make_torch_inplace/csrc/pymodule.cpp",
                "make_torch_inplace/csrc/square_matmul.cu",
                "make_torch_inplace/csrc/softmax.cu",
                "make_torch_inplace/csrc/layernorm.cu",
            ],
            include_dirs=[
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    "make_torch_inplace/csrc/",
                )
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": (
                    ["-O3", "--use_fast_math"]
                    + version_dependent_macros
                    + extra_cuda_flags
                ),
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)
```
Links