[DRAFT] Extract raw QK logits from prefix_prefill triton kernel. #44

Draft: wants to merge 3 commits into base: compact

SolitaryThinker commented Dec 9, 2024

First attempt at logit extraction: the kernel writes out the raw QK logits, without the softmax normalization. This costs a 15-30% slowdown compared to the unmodified Triton kernel.

Note: this PR currently contains the version that casts the logits to bf16 before writing them out.
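
To make "raw QK logits without the softmax normalization" concrete, here is a minimal pure-Python sketch of the quantity involved. It is not the Triton kernel from this PR: the names `raw_qk_logits` and `softmax` and the `scale` parameter are hypothetical, and whether the kernel applies the attention scale before writing the logits out is an assumption left open here.

```python
import math

def raw_qk_logits(q, keys, scale=1.0):
    """Dot product of one query vector with every key vector.

    `scale` is a hypothetical parameter (e.g. 1/sqrt(head_size));
    the PR text does not say whether the written logits are scaled.
    """
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]

def softmax(logits):
    """Numerically stable softmax: the normalization step that the
    extracted logits do NOT include."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example: one query, three keys, head_size = 2.
q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]]

logits = raw_qk_logits(q, keys)  # what "raw logits" refers to
probs = softmax(logits)          # what attention normally uses

print("raw logits:", logits)
print("softmax   :", probs)
```

The raw logits are unbounded and do not sum to one; a consumer that needs attention weights would have to apply the softmax itself.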

SolitaryThinker commented Dec 9, 2024

BF16
Raw benchmark numbers over 1k iterations of the kernel:

Format: batch_size, seq_len, block_size, head_size, expose_time(ms), unmodified_time(ms), slowdown(%)
1, 128, 16, 128, 0.12, 0.10, 17.82
1, 512, 16, 128, 0.11, 0.10, 16.11
1, 1024, 16, 128, 0.26, 0.21, 22.69
1, 2048, 16, 128, 0.76, 0.64, 19.21
1, 4096, 16, 128, 2.32, 1.99, 16.25
4, 128, 16, 128, 0.11, 0.10, 14.86
4, 512, 16, 128, 0.24, 0.18, 31.01
4, 1024, 16, 128, 0.69, 0.57, 22.17
4, 2048, 16, 128, 2.47, 2.09, 18.57
4, 4096, 16, 128, 9.00, 7.72, 16.59
8, 128, 16, 128, 0.11, 0.10, 15.53
8, 512, 16, 128, 0.40, 0.31, 28.70
8, 1024, 16, 128, 1.29, 1.08, 19.40
8, 2048, 16, 128, 4.66, 3.91, 19.08
8, 4096, 16, 128, 17.57, 15.12, 16.18
16, 128, 16, 128, 0.17, 0.14, 14.64
16, 512, 16, 128, 0.74, 0.59, 24.85
16, 1024, 16, 128, 2.45, 2.05, 19.45
16, 2048, 16, 128, 9.16, 7.71, 18.74
16, 4096, 16, 128, 34.91, 30.07, 16.09
32, 128, 16, 128, 0.32, 0.28, 14.10
32, 512, 16, 128, 1.39, 1.12, 24.71
32, 1024, 16, 128, 4.81, 4.02, 19.43
32, 2048, 16, 128, 18.20, 15.40, 18.13
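
The slowdown column appears consistent with the relative difference between the two timings. The sketch below shows that calculation; the formula is inferred from the numbers rather than confirmed by the benchmark script, and `slowdown_pct` is a hypothetical helper name. Because the times in the table are rounded to two decimals, recomputing from them only approximates the reported percentages, and the gap is largest for the sub-millisecond rows.

```python
def slowdown_pct(expose_ms, unmodified_ms):
    """Relative slowdown of the logit-exposing kernel, in percent.

    Assumed formula: (expose - unmodified) / unmodified * 100.
    """
    return (expose_ms - unmodified_ms) / unmodified_ms * 100.0

# Two rows from the BF16 table above: (expose, unmodified, reported slowdown).
rows = [
    (9.00, 7.72, 16.59),    # batch_size=4,  seq_len=4096
    (34.91, 30.07, 16.09),  # batch_size=16, seq_len=4096
]

for expose_ms, unmodified_ms, reported in rows:
    recomputed = slowdown_pct(expose_ms, unmodified_ms)
    print(f"recomputed {recomputed:.2f}% vs reported {reported:.2f}%")
```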

SolitaryThinker (Author) commented

float32
Performance seems slightly better if we don't cast to bf16: the slowdown is roughly 12-19% here, versus roughly 14-31% with the bf16 cast.

Format: batch_size, seq_len, block_size, head_size, expose_time(ms), prefix_time(ms), slowdown(%)
1, 128, 16, 128, 0.11, 0.09, 17.42
1, 512, 16, 128, 0.11, 0.09, 17.41
1, 1024, 16, 128, 0.25, 0.21, 18.61
1, 2048, 16, 128, 0.71, 0.64, 12.35
1, 4096, 16, 128, 2.23, 2.00, 11.65
4, 128, 16, 128, 0.11, 0.10, 16.93
4, 512, 16, 128, 0.22, 0.18, 18.21
4, 1024, 16, 128, 0.65, 0.57, 14.95
4, 2048, 16, 128, 2.35, 2.09, 12.52
4, 4096, 16, 128, 8.67, 7.72, 12.31
8, 128, 16, 128, 0.11, 0.10, 16.04
8, 512, 16, 128, 0.36, 0.31, 16.64
8, 1024, 16, 128, 1.23, 1.08, 13.28
8, 2048, 16, 128, 4.41, 3.91, 12.66
8, 4096, 16, 128, 16.98, 15.12, 12.29
16, 128, 16, 128, 0.16, 0.14, 13.14
16, 512, 16, 128, 0.68, 0.59, 14.74
16, 1024, 16, 128, 2.32, 2.05, 13.28
16, 2048, 16, 128, 8.70, 7.72, 12.76
16, 4096, 16, 128, 33.72, 30.06, 12.16
32, 128, 16, 128, 0.32, 0.28, 12.68
32, 512, 16, 128, 1.28, 1.12, 15.08
32, 1024, 16, 128, 4.59, 4.02, 14.07
32, 2048, 16, 128, 17.32, 15.40, 12.45
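
The benchmark script itself is not shown in this thread. As a rough sketch of how a per-call time over 1k iterations could be measured, the helper below times a callable with warmup runs and an optional synchronization hook. `mean_time_ms` and its parameters are hypothetical; for a GPU kernel launched from PyTorch one would presumably pass `torch.cuda.synchronize` as `sync`, since kernel launches are asynchronous.

```python
import time

def mean_time_ms(fn, iters=1000, warmup=10, sync=None):
    """Average wall-clock time of fn() in milliseconds.

    fn     : zero-argument callable to time (e.g. a kernel launch wrapper)
    iters  : number of timed iterations
    warmup : untimed iterations run first (JIT compilation, caches)
    sync   : optional zero-argument callable that blocks until queued
             work has finished; needed for asynchronous GPU launches
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()  # make sure all timed work is done before stopping the clock
    elapsed_s = time.perf_counter() - start

    return elapsed_s * 1000.0 / iters

# CPU-only stand-in workload, just to show the call shape.
elapsed = mean_time_ms(lambda: sum(range(1000)), iters=1000)
print(f"{elapsed:.4f} ms per call")
```

Timing both kernels with the same harness, shapes, and inputs keeps the comparison between the two columns fair.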
