包含以下内容:
- hardswish_f32_kernel
- hardswish_f32x4_kernel(float4向量化版本)
- hardswish_f16_kernel(fp16版本)
- hardswish_f16x2_kernel(fp16向量化版本, x2)
- hardswish_f16x8_kernel(fp16向量化版本, x8)
- hardswish_f16x8_pack_kernel(fp16向量化,pack版本)
- PyTorch bindings
# 只测试Ada架构。若不指定, 则默认编译所有架构(Volta, Ampere, Ada, Hopper, ...), 耗时较长
export TORCH_CUDA_ARCH_LIST=Ada
python3 hardswish.py
输出:
-------------------------------------------------------------------------------------
S=1024, K=1024
out_f32: ['-0.31692463 ', '-0.13540865 '], time:0.00399518ms
out_f32x4: ['-0.31692463 ', '-0.13540865 '], time:0.00348544ms
out_f32_th: ['-0.31692463 ', '-0.13540865 '], time:0.00680089ms
-------------------------------------------------------------------------------------
out_f16: ['-0.31713867 ', '-0.13537598 '], time:0.00405478ms
out_f16x2: ['-0.31713867 ', '-0.13537598 '], time:0.00265884ms
out_f16x8: ['-0.31713867 ', '-0.13537598 '], time:0.00252485ms
out_f16x8pack: ['-0.31713867 ', '-0.13537598 '], time:0.00242925ms
out_f16_th: ['-0.31689453 ', '-0.13537598 '], time:0.00608802ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=2048
out_f32: ['0.02895264 ', '0.06114347 '], time:0.00553298ms
out_f32x4: ['0.02895264 ', '0.06114347 '], time:0.00528932ms
out_f32_th: ['0.02895264 ', '0.06114347 '], time:0.01048636ms
-------------------------------------------------------------------------------------
out_f16: ['0.02894592 ', '0.06112671 '], time:0.00549722ms
out_f16x2: ['0.02894592 ', '0.06112671 '], time:0.00471425ms
out_f16x8: ['0.02894592 ', '0.06112671 '], time:0.00379252ms
out_f16x8pack: ['0.02894592 ', '0.06112671 '], time:0.00358367ms
out_f16_th: ['0.02894592 ', '0.06115723 '], time:0.00684929ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=1024, K=4096
out_f32: ['1.28615212 ', '0.17574076 '], time:0.00932550ms
out_f32x4: ['1.28615212 ', '0.17574076 '], time:0.00886083ms
out_f32_th: ['1.28615212 ', '0.17574076 '], time:0.01790905ms
-------------------------------------------------------------------------------------
out_f16: ['1.28613281 ', '0.17578125 '], time:0.00924945ms
out_f16x2: ['1.28613281 ', '0.17578125 '], time:0.00908136ms
out_f16x8: ['1.28613281 ', '0.17578125 '], time:0.00614285ms
out_f16x8pack: ['1.28613281 ', '0.17578125 '], time:0.00555348ms
out_f16_th: ['1.28613281 ', '0.17578125 '], time:0.01063919ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=1024
out_f32: ['2.739393 ', '0.04897798 '], time:0.00654101ms
out_f32x4: ['2.739393 ', '0.04897798 '], time:0.00531363ms
out_f32_th: ['2.739393 ', '0.04897798 '], time:0.01048684ms
-------------------------------------------------------------------------------------
out_f16: ['2.73632812 ', '0.04898071 '], time:0.00676942ms
out_f16x2: ['2.73632812 ', '0.04898071 '], time:0.00383520ms
out_f16x8: ['2.73632812 ', '0.04898071 '], time:0.00384569ms
out_f16x8pack: ['2.73632812 ', '0.04898071 '], time:0.00372910ms
out_f16_th: ['2.73828125 ', '0.04898071 '], time:0.00684285ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=2048
out_f32: ['0.00921244 ', '-0.36749047 '], time:0.00932741ms
out_f32x4: ['0.00921244 ', '-0.36749047 '], time:0.00964785ms
out_f32_th: ['0.00921244 ', '-0.36749047 '], time:0.01939940ms
-------------------------------------------------------------------------------------
out_f16: ['0.00920868 ', '-0.36743164 '], time:0.00925016ms
out_f16x2: ['0.00920868 ', '-0.36743164 '], time:0.00796676ms
out_f16x8: ['0.00920868 ', '-0.36743164 '], time:0.00587964ms
out_f16x8pack: ['0.00920868 ', '-0.36743164 '], time:0.00551319ms
out_f16_th: ['0.00920868 ', '-0.36743164 '], time:0.01064467ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=2048, K=4096
out_f32: ['1.23418164 ', '-0.35043269 '], time:0.01706600ms
out_f32x4: ['1.23418164 ', '-0.35043269 '], time:0.01722789ms
out_f32_th: ['1.23418164 ', '-0.35043269 '], time:0.09308505ms
-------------------------------------------------------------------------------------
out_f16: ['1.23535156 ', '-0.35058594 '], time:0.01689029ms
out_f16x2: ['1.23535156 ', '-0.35058594 '], time:0.01665306ms
out_f16x8: ['1.23535156 ', '-0.35058594 '], time:0.01000905ms
out_f16x8pack: ['1.23535156 ', '-0.35058594 '], time:0.00916457ms
out_f16_th: ['1.234375 ', '-0.3503418 '], time:0.01818967ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=1024
out_f32: ['1.13679588 ', '2.5862627 '], time:0.01175451ms
out_f32x4: ['1.13679588 ', '2.5862627 '], time:0.00892878ms
out_f32_th: ['1.13679588 ', '2.5862627 '], time:0.01798749ms
-------------------------------------------------------------------------------------
out_f16: ['1.13671875 ', '2.5859375 '], time:0.01221919ms
out_f16x2: ['1.13671875 ', '2.5859375 '], time:0.00619817ms
out_f16x8: ['1.13671875 ', '2.5859375 '], time:0.00586224ms
out_f16x8pack: ['1.13671875 ', '2.5859375 '], time:0.00551724ms
out_f16_th: ['1.13671875 ', '2.5859375 '], time:0.01065254ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=2048
out_f32: ['-0.11715776 ', '-0.03201698 '], time:0.01850605ms
out_f32x4: ['-0.11715776 ', '-0.03201698 '], time:0.01827025ms
out_f32_th: ['-0.11715776 ', '-0.03201698 '], time:0.09311175ms
-------------------------------------------------------------------------------------
out_f16: ['-0.11712646 ', '-0.03201294 '], time:0.01689458ms
out_f16x2: ['-0.11712646 ', '-0.03201294 '], time:0.01446867ms
out_f16x8: ['-0.11712646 ', '-0.03201294 '], time:0.00979257ms
out_f16x8pack: ['-0.11712646 ', '-0.03201294 '], time:0.00915074ms
out_f16_th: ['-0.11712646 ', '-0.03204346 '], time:0.01819777ms
-------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------
S=4096, K=4096
out_f32: ['-0.36491728 ', '1.09765351 '], time:0.14531279ms
out_f32x4: ['-0.36491728 ', '1.09765351 '], time:0.14574075ms
out_f32_th: ['-0.36491728 ', '1.09765351 '], time:0.29305243ms
-------------------------------------------------------------------------------------
out_f16: ['-0.36499023 ', '1.09765625 '], time:0.03205943ms
out_f16x2: ['-0.36499023 ', '1.09765625 '], time:0.03170896ms
out_f16x8: ['-0.36499023 ', '1.09765625 '], time:0.01824880ms
out_f16x8pack: ['-0.36499023 ', '1.09765625 '], time:0.01677060ms
out_f16_th: ['-0.36499023 ', '1.09765625 '], time:0.09315753ms
-------------------------------------------------------------------------------------