I've been trying to integrate this kernel to compute loss/grads with large activations and vocabularies. I noticed significant changes in the kernel's runtime depending on the distribution of the input data.
This sounds a bit insane, but I have a small reproducer that shows this behaviour:
import cut_cross_entropy as cce_lib
import torch
import time

device = torch.device('cuda:0')

def benchmark_fn(embeddings, vocab, labels):
    # clone, set requires_grad, and cast to device (though the last step should not be needed)
    embeddings, vocab = torch.clone(embeddings).requires_grad_(True).to(device), torch.clone(vocab).requires_grad_(True).to(device)
    labels = torch.clone(labels).to(device)
    torch.cuda.synchronize()
    start_time = time.time()
    loss = cce_lib.linear_cross_entropy(embeddings, vocab, labels)
    loss.backward()
    torch.cuda.synchronize()
    print(time.time() - start_time)

# generate inputs
embeddings = torch.randn(4 * 8192, 4096, device=device, dtype=torch.bfloat16)
vocab = torch.randn(256_000, 4096, device=device, dtype=torch.bfloat16)
labels = torch.randint(0, 256_000, (4 * 8192,), device=device)

# compile call, exclude from loops
print("first call")
benchmark_fn(embeddings, vocab, labels)

print("regular inputs")
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)

print("scaled down weights result in slowdown!")
vocab = vocab * 1 / 8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
vocab = vocab * 8

print("scaled up weights do not.")
vocab = vocab * 8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
vocab = vocab * 1 / 8

print("scaled down activations also cause slowdown")
embeddings = embeddings * 1 / 8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
embeddings = embeddings * 8

print("but not scaled up")
embeddings = embeddings * 8
for _ in range(5):
    benchmark_fn(embeddings, vocab, labels)
embeddings = embeddings * 1 / 8
which will output:
first call
1.8656930923461914
regular inputs
0.4381716251373291
0.4551575183868408
0.4390103816986084
0.44695091247558594
0.44449806213378906
scaled down weights result in slowdown!
1.2427277565002441
1.2380731105804443
1.2460360527038574
1.2462165355682373
1.245213270187378
scaled up weights do not.
0.4104018211364746
0.41744208335876465
0.42533063888549805
0.41300249099731445
0.4272499084472656
scaled down activations also cause slowdown
1.2431302070617676
1.2402148246765137
1.2415409088134766
1.2405052185058594
1.2423794269561768
but not scaled up
0.41109371185302734
0.41783595085144043
0.42434120178222656
0.4148707389831543
0.42697978019714355
By simply scaling down either the weights or the activations by a factor of 8, the runtime roughly triples (from ~0.44 s per call to ~1.24 s).
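For what it's worth, both scalings should look identical from the kernel's point of view, assuming linear_cross_entropy computes the logits as embeddings @ vocab.T: dividing either operand by 8 just divides every logit by 8. A quick sanity check on a small slice (my own addition, not part of the reproducer):

# dividing either operand of the matmul by 8 divides the resulting logits by 8
e = embeddings[:16].float()
w = vocab[:16].float()
torch.testing.assert_close(e @ (w / 8).T, (e @ w.T) / 8)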
I am kinda at a loss (hah) to understand why this could happen! I have checked, by passing additional debug flags, that no recompilations are happening, so this is a single compiled kernel exhibiting this behaviour.
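In case it helps narrow things down, here is a minimal sketch that times the forward and backward passes separately with CUDA events, to see which pass the slowdown comes from. It reuses the inputs from the reproducer above; the helper name timed_fwd_bwd is just my own.

def timed_fwd_bwd(embeddings, vocab, labels):
    # same setup as benchmark_fn, but with separate forward/backward timings
    embeddings = torch.clone(embeddings).requires_grad_(True)
    vocab = torch.clone(vocab).requires_grad_(True)

    fwd_start = torch.cuda.Event(enable_timing=True)
    fwd_end = torch.cuda.Event(enable_timing=True)
    bwd_end = torch.cuda.Event(enable_timing=True)

    fwd_start.record()
    loss = cce_lib.linear_cross_entropy(embeddings, vocab, labels)
    fwd_end.record()
    loss.backward()
    bwd_end.record()
    torch.cuda.synchronize()

    # elapsed_time() reports milliseconds
    print(f"forward: {fwd_start.elapsed_time(fwd_end):.1f} ms, "
          f"backward: {fwd_end.elapsed_time(bwd_end):.1f} ms")

timed_fwd_bwd(embeddings, vocab, labels)          # regular inputs
timed_fwd_bwd(embeddings, vocab * 1 / 8, labels)  # scaled-down weights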
Relevant parts of my env:
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 18.1.8 (++20240731024944+3b5b5c1ec4a3-1~exp1~20240731145000.144)
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.11 (main, Nov 17 2024, 19:27:51) [GCC 11.4.0] (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 12.6.68
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.216.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.5.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.5.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==2.2.0
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] Could not collect
Thank you for your assistance 🙏