Project repository: https://github.com/jinmingyi1998/make-torch-inplace

Intro

As is well known, attention consumes a lot of GPU memory, and the matrix multiplications inside it account for a large share of that space.

Matrix Multiplication

A general matrix multiplication is $A_{m\times k} \times B_{k\times n} = C_{m\times n}$.

In self-attention, however, the weight matrix in many of these multiplications is square, i.e. $A_{m\times n} \times B_{n\times n} = C_{m\times n}$. Here $C$ has the same shape as $A$, so if $A$ is no longer needed, the result $C$ can be written straight back into $A$'s storage during the computation.

The Torch matmul API does not care about any of this: it is a general-purpose routine, so C is never placed into A. Even when it could be, it trades space for time for the sake of speed, because with the usual GEMM tiling each element of A is read multiple times and therefore cannot be overwritten in place.
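A minimal PyTorch sketch of the shape argument above (the dimensions are made up for illustration):

import torch

m, n = 4096, 1024
A = torch.randn(m, n, device="cuda")
W = torch.randn(n, n, device="cuda")   # square weight matrix, as in self-attention

C = A @ W                  # torch allocates a brand-new (m, n) tensor for C
assert C.shape == A.shape  # same shape, so C could in principle overwrite A's storage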

So let me write an in-place square-matrix multiplication myself.

Torch CUDA Extension

CUDA C

CUDA Intro

CUDA code contains both device-side (GPU) and host-side (CPU) code, and functions can even be templated (NVCC is impressive).

Threads in a kernel are organized as Grid - Block - Thread. Every 32 consecutive threads (the warp size is typically 32) form a warp, and all threads of a warp belong to the same block. The warp is the smallest scheduling unit; if the block size is not a multiple of the warp size, it is padded up to one.

(Figure: CUDA thread hierarchy)
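A tiny sketch of the index arithmetic this hierarchy implies (assuming a warp size of 32):

__global__ void index_demo() {
  int tid     = blockIdx.x * blockDim.x + threadIdx.x;  // global thread id
  int warp_id = threadIdx.x / 32;                       // warp index inside the block
  int lane_id = threadIdx.x % 32;                       // lane (position) inside the warp
  // the 32 threads sharing a warp_id are scheduled together as one warp
}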

The memory types are global, shared, and local; for more details see the earlier post CUDA Shared Memory.

kernel function

A not-particularly-efficient implementation

Each row of C is assigned to one warp: a block covers ROW_PER_BLOCK rows, and the warp index (threadIdx.x / 32) selects the row within the block.

Within the warp, each of the 32 lanes handles a contiguous segment of cols_per_thread output columns of that row and computes the full dot product for each of them. The per-lane results are buffered and only written back over the input row after a __syncthreads(), so the row is never overwritten while it is still being read.

template <typename T, typename T_W>
__global__ void square_matmul_inplace_kernel_(T *mat, T_W *square_mat, int rows,
                                              int cols, bool transpose_square) {
  // ROW_PER_BLOCK (rows of C handled per block) is defined elsewhere in the source.
  int lane_id = threadIdx.x % 32;                             // lane inside the warp
  int C_row = blockIdx.x * ROW_PER_BLOCK + threadIdx.x / 32;  // one warp per output row

  // Split the row's columns into 32 segments, one per lane.
  int cols_per_thread = (cols + 31) / 32;
  int cols_this_thread = cols_per_thread;
  int last_y = cols / cols_per_thread;

  if (lane_id == last_y) {
    cols_this_thread = cols - cols_per_thread * last_y;  // lane last_y gets the remaining columns
  } else if (lane_id > last_y) {
    cols_this_thread = 0;                                // lanes past the end do nothing
  }

  T *row_input = mat + C_row * cols;  // the input row that will be overwritten
  float thread_sum_buf[64];           // per-lane output buffer (assumes cols_per_thread <= 64)

  if (C_row < rows) {
    for (int i = 0; i < cols_this_thread; i++) {
      int C_col = lane_id * cols_per_thread + i;
      float thread_sum = 0.0;
      // Full dot product of the input row with column C_col of the square matrix.
      for (int j = 0; j < cols; j++) {
        int mat2_idx;
        if (transpose_square) {
          mat2_idx = C_col * cols + j;  // the square matrix is accessed as its transpose
        } else {
          mat2_idx = j * cols + C_col;
        }
        float v1 = static_cast<float>(row_input[j]);
        float v2 = static_cast<float>(square_mat[mat2_idx]);
        thread_sum += v1 * v2;
      }
      thread_sum_buf[i] = thread_sum;
    }
  }
  // Make sure every thread has finished reading the row before it is overwritten.
  __syncthreads();
  if (C_row < rows) {
    for (int i = 0; i < cols_this_thread; i++) {
      int C_col = lane_id * cols_per_thread + i;
      row_input[C_col] = static_cast<T>(thread_sum_buf[i]);
    }
  }
}
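The indexing above implies a launch configuration along the following lines. This is only a sketch, not the repository's actual wrapper; ROW_PER_BLOCK (rows of C per block) and the stream handling are assumptions:

template <typename T, typename T_W>
void launch_square_matmul_inplace(T *mat, T_W *square_mat, int rows, int cols,
                                  bool transpose_square, cudaStream_t stream) {
  dim3 block(ROW_PER_BLOCK * 32);  // one warp (32 threads) per output row
  dim3 grid((rows + ROW_PER_BLOCK - 1) / ROW_PER_BLOCK);
  square_matmul_inplace_kernel_<T, T_W><<<grid, block, 0, stream>>>(
      mat, square_mat, rows, cols, transpose_square);
}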
The qualifiers used above (a short usage sketch follows this list):
__global__ declares a kernel, which is called from the host and executed on the device
__device__ declares a device function, which is called from and executed on the device
__host__ declares a host function, which is called from and executed on the host
__noinline__ prevents a function from being inlined
__forceinline__ forces a function to be inlined
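A minimal illustration of these qualifiers (the helper names are hypothetical):

__device__ __forceinline__ float square(float x) { return x * x; }  // device-only, force-inlined

__global__ void square_all_kernel(float *data, int n) {  // kernel: launched from the host
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = square(data[i]);
}

__host__ void fill_ones(float *data, int n) {  // ordinary CPU-side function
  for (int i = 0; i < n; i++) data[i] = 1.0f;
}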

C API

Just pybind11

#include <torch/extension.h>

void square_matmul_inplace_(at::Tensor input, at::Tensor square_mat, int rows,
                            int cols);

void square_matmul_inplace_T_(at::Tensor input, at::Tensor square_mat, int rows,
                              int cols);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("square_matmul", &square_matmul_inplace_, "Matmul inplace (CUDA)");
  m.def("square_matmul_transposed", &square_matmul_inplace_T_,
        "Matmul inplace transposed (CUDA)");
}
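Once the extension is built (the module name make_torch_inplace_C comes from the setup file below), it could be called from Python roughly like this; that the rows/cols arguments are the input's shape is an assumption based on the declarations above:

import torch
import make_torch_inplace_C

x = torch.randn(64, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")

# Overwrites x with x @ w instead of allocating a new output tensor.
make_torch_inplace_C.square_matmul(x, w, x.size(0), x.size(1))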

Python Setup File

Use PyTorch cpp_extension

Example code (needs to be adapted to your project)

import os
import subprocess
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension

version_dependent_macros = [
    "-DVERSION_GE_1_1",
    "-DVERSION_GE_1_3",
    "-DVERSION_GE_1_5",
]

extra_cuda_flags = [
    "-std=c++14",
    "-maxrregcount=50",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
]

cc_number = [75, 80, 86, 89]
cc_flag = []
for n in cc_number:
    cc_flag.append("-gencode")
    cc_flag.append(f"arch=compute_{n},code=sm_{n}")

print(cc_flag)

extra_cuda_flags += cc_flag


def get_cuda_bare_metal_version(cuda_dir=CUDA_HOME):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


nvcc_raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version()

# An example setup()
setup(
    packages=find_packages(),
    include_package_data=True,
    package_data={
        "make_torch_inplace": ["csrc/*"],
    },
    ext_modules=[
        CUDAExtension(
            name="make_torch_inplace_C",
            sources=[
                "make_torch_inplace/csrc/pymodule.cpp",
                "make_torch_inplace/csrc/square_matmul.cu",
                "make_torch_inplace/csrc/softmax.cu",
                "make_torch_inplace/csrc/layernorm.cu",
            ],
            include_dirs=[
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    "make_torch_inplace/csrc/",
                )
            ],
            extra_compile_args={
                "cxx": ["-O3"] + version_dependent_macros,
                "nvcc": (
                    ["-O3", "--use_fast_math"]
                    + version_dependent_macros
                    + extra_cuda_flags
                ),
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    install_requires=["torch"],
)
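After adjusting the package name, sources, and compute capabilities for your own project, the extension should build with a regular pip install . (or python setup.py build_ext --inplace during development).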