包含以下内容:
- elu_f32_kernel
- elu_f32x4_kernel(float4向量化版本)
- elu_f16_kernel(fp16版本)
- elu_f16x2_kernel(fp16向量化版本，一次处理2个fp16元素)
- elu_f16x8_kernel(fp16向量化版本，一次处理8个fp16元素)
- elu_f16x8_pack_kernel(fp16向量化版本，一次处理8个fp16元素，pack版本)
- PyTorch bindings
# 只编译/测试Ada架构；若不设置TORCH_CUDA_ARCH_LIST，则默认编译所有架构（Volta, Ampere, Ada, Hopper, ...），耗时较长
export TORCH_CUDA_ARCH_LIST=Ada
python3 elu.py
输出:
-------------------------------------------------------------------------------------
S=1024, K=1024
out_f32: ['2.11984897 ', '0.61072099 '], time:0.00429201ms
out_f32x4: ['2.11984897 ', '0.61072099 '], time:0.00371671ms
out_f32_th: ['2.11984897 ', '0.61072099 '], time:0.01802135ms
-------------------------------------------------------------------------------------
out_f16: ['2.11914062 ', '0.61083984 '], time:0.00445604ms
out_f16x2: ['2.11914062 ', '0.61083984 '], time:0.00298023ms
out_f16x8: ['2.11914062 ', '0.61083984 '], time:0.00276542ms
out_f16x8pack: ['2.11914062 ', '0.61083984 '], time:0.00266290ms
out_f16_th: ['2.11914062 ', '0.61083984 '], time:0.01671553ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=2048
out_f32: ['-0.18413025 ', '-0.36962259 '], time:0.00586891ms
out_f32x4: ['-0.18413025 ', '-0.36962259 '], time:0.00565982ms
out_f32_th: ['-0.18413025 ', '-0.36962259 '], time:0.02765131ms
-------------------------------------------------------------------------------------
out_f16: ['-0.18408203 ', '-0.36962891 '], time:0.00610733ms
out_f16x2: ['-0.18408203 ', '-0.36962891 '], time:0.00525951ms
out_f16x8: ['-0.18408203 ', '-0.36962891 '], time:0.00424457ms
out_f16x8pack: ['-0.18408203 ', '-0.36962891 '], time:0.00393462ms
out_f16_th: ['-0.18408203 ', '-0.36962891 '], time:0.01878762ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=4096
out_f32: ['-0.44692516 ', '0.08315884 '], time:0.00999045ms
out_f32x4: ['-0.44692516 ', '0.08315884 '], time:0.00968099ms
out_f32_th: ['-0.4469251 ', '0.08315884 '], time:0.04965019ms
-------------------------------------------------------------------------------------
out_f16: ['-0.44677734 ', '0.08312988 '], time:0.01032424ms
out_f16x2: ['-0.44677734 ', '0.08312988 '], time:0.00987983ms
out_f16x8: ['-0.44677734 ', '0.08312988 '], time:0.00668359ms
out_f16x8pack: ['-0.44677734 ', '0.08312988 '], time:0.00599647ms
out_f16_th: ['-0.44677734 ', '0.08312988 '], time:0.03015637ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=1024
out_f32: ['-0.57048458 ', '-0.37474912 '], time:0.00696707ms
out_f32x4: ['-0.57048458 ', '-0.37474912 '], time:0.00568914ms
out_f32_th: ['-0.57048458 ', '-0.37474912 '], time:0.02774405ms
-------------------------------------------------------------------------------------
out_f16: ['-0.5703125 ', '-0.37451172 '], time:0.00749683ms
out_f16x2: ['-0.5703125 ', '-0.37451172 '], time:0.00430250ms
out_f16x8: ['-0.5703125 ', '-0.37451172 '], time:0.00412726ms
out_f16x8pack: ['-0.5703125 ', '-0.37451172 '], time:0.00405145ms
out_f16_th: ['-0.5703125 ', '-0.37451172 '], time:0.01877379ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=2048
out_f32: ['0.31947908 ', '0.2870425 '], time:0.00998878ms
out_f32x4: ['0.31947908 ', '0.2870425 '], time:0.00961614ms
out_f32_th: ['0.31947908 ', '0.2870425 '], time:0.05018497ms
-------------------------------------------------------------------------------------
out_f16: ['0.31958008 ', '0.28710938 '], time:0.01032591ms
out_f16x2: ['0.31958008 ', '0.28710938 '], time:0.00894380ms
out_f16x8: ['0.31958008 ', '0.28710938 '], time:0.00641274ms
out_f16x8pack: ['0.31958008 ', '0.28710938 '], time:0.00598955ms
out_f16_th: ['0.31958008 ', '0.28710938 '], time:0.03003049ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: ['-0.45947188 ', '-0.51995623 '], time:0.01824403ms
out_f32x4: ['-0.45947188 ', '-0.51995623 '], time:0.01857781ms
out_f32_th: ['-0.45947188 ', '-0.51995623 '], time:0.23533177ms
-------------------------------------------------------------------------------------
out_f16: ['-0.45947266 ', '-0.52001953 '], time:0.01726103ms
out_f16x2: ['-0.45947266 ', '-0.52001953 '], time:0.01664019ms
out_f16x8: ['-0.45947266 ', '-0.52001953 '], time:0.01005602ms
out_f16x8pack: ['-0.45947266 ', '-0.52001953 '], time:0.00917339ms
out_f16_th: ['-0.45947266 ', '-0.52001953 '], time:0.05084753ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=1024
out_f32: ['1.88309503 ', '-0.73455477 '], time:0.01146603ms
out_f32x4: ['1.88309503 ', '-0.73455477 '], time:0.00886965ms
out_f32_th: ['1.88309503 ', '-0.73455477 '], time:0.08757877ms
-------------------------------------------------------------------------------------
out_f16: ['1.8828125 ', '-0.734375 '], time:0.01243806ms
out_f16x2: ['1.8828125 ', '-0.734375 '], time:0.00632644ms
out_f16x8: ['1.8828125 ', '-0.734375 '], time:0.00585818ms
out_f16x8pack: ['1.8828125 ', '-0.734375 '], time:0.00601745ms
out_f16_th: ['1.8828125 ', '-0.734375 '], time:0.03008652ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=2048
out_f32: ['0.4883095 ', '-0.83807635 '], time:0.01823831ms
out_f32x4: ['0.4883095 ', '-0.83807635 '], time:0.01836205ms
out_f32_th: ['0.4883095 ', '-0.83807635 '], time:0.23539877ms
-------------------------------------------------------------------------------------
out_f16: ['0.48828125 ', '-0.83789062 '], time:0.01728797ms
out_f16x2: ['0.48828125 ', '-0.83789062 '], time:0.01487613ms
out_f16x8: ['0.48828125 ', '-0.83789062 '], time:0.00981522ms
out_f16x8pack: ['0.48828125 ', '-0.83789062 '], time:0.00917029ms
out_f16_th: ['0.48828125 ', '-0.83789062 '], time:0.05086207ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=4096
out_f32: ['1.15699899 ', '0.59496713 '], time:0.14590597ms
out_f32x4: ['1.15699899 ', '0.59496713 '], time:0.14612436ms
out_f32_th: ['1.15699899 ', '0.59496713 '], time:0.76177263ms
-------------------------------------------------------------------------------------
out_f16: ['1.15722656 ', '0.59472656 '], time:0.03287864ms
out_f16x2: ['1.15722656 ', '0.59472656 '], time:0.03170896ms
out_f16x8: ['1.15722656 ', '0.59472656 '], time:0.01807237ms
out_f16x8pack: ['1.15722656 ', '0.59472656 '], time:0.01692462ms
out_f16_th: ['1.15722656 ', '0.59472656 '], time:0.26625180ms
-------------------------------------------------------------------------------------