Fix number of return values for Flash2 (fairinternal/xformers#1279)
Probably we should have a test somewhere, but it becomes hard to cover all combinations of {flash version} × {GPU} × {PyTorch version} × {built-in Flash / PT flash / third-party flash}.

__original_commit__ = fairinternal/xformers@9c39da4
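A minimal sketch of the kind of smoke test the message alludes to (an assumption, not part of this commit): it drives the flash forward path through xformers' public memory_efficient_attention with the flash op pair pinned, so a mismatch in the C extension's return signature would surface as an unpacking error at call time. Covering the full version/GPU/build matrix would still require a CI runner per combination.

# Hypothetical smoke test, not part of this commit: pin the flash-attn
# ops so a change in _C_flashattention.fwd's return signature surfaces
# as an unpack/shape error when the op is invoked.
import pytest
import torch

xops = pytest.importorskip("xformers.ops")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_flash_fwd_smoke():
    # (batch, seqlen, heads, head_dim), fp16 as flash-attn expects
    q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = xops.memory_efficient_attention(
        q, k, v, op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
    )
    assert out.shape == q.shape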
danthe3rd authored and xFormers Bot committed Dec 30, 2024
1 parent a8746f3 commit 46a02df
Showing 1 changed file with 3 additions and 21 deletions.

xformers/ops/fmha/flash.py
@@ -62,7 +62,7 @@
     )
 
     FLASH_VERSION = flash_attn.__version__
-    FLASH_VER_MIN = (2, 6, 3)
+    FLASH_VER_MIN = (2, 7, 1)
     FLASH_VER_LAST = (2, 7, 2)  # last supported, inclusive
     flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
     if (
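The bumped minimum is enforced by the tuple comparison that the truncated if at the end of this hunk opens. A toy illustration (the constants mirror the hunk; the loop and prints are ours, and the real check also honors an escape-hatch environment variable) of why lexicographic tuple comparison now rejects 2.6.3:

# Toy illustration, not from the diff: Python compares tuples
# lexicographically, so the raised minimum rejects flash-attn 2.6.3.
FLASH_VER_MIN = (2, 7, 1)
FLASH_VER_LAST = (2, 7, 2)  # last supported, inclusive

for version in ("2.6.3", "2.7.1", "2.7.2", "2.7.3"):
    parsed = tuple(int(s) for s in version.split(".")[:3])
    ok = FLASH_VER_MIN <= parsed <= FLASH_VER_LAST
    print(version, "supported" if ok else "unsupported")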
@@ -136,16 +136,7 @@ def _flash_fwd(
         if cu_seqlens_q is None:
             assert cu_seqlens_k is None
             assert seqused_k is None
-            (
-                out,
-                q_padded,
-                k_padded,
-                v_padded,
-                out_padded,
-                softmax_lse,
-                p,
-                rng_state,
-            ) = _C_flashattention.fwd(
+            out, softmax_lse, p, rng_state = _C_flashattention.fwd(
                 query,
                 key,
                 value,
@@ -161,16 +152,7 @@
                 None,  # rng
             )
         else:
-            (
-                out,
-                q_padded,
-                k_padded,
-                v_padded,
-                out_padded,
-                softmax_lse,
-                p,
-                rng_state,
-            ) = _C_flashattention.varlen_fwd(
+            out, softmax_lse, p, rng_state = _C_flashattention.varlen_fwd(
                 query,
                 key,
                 value,
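The two _flash_fwd hunks replace an 8-value unpack with a 4-value one: per this diff, the padded q/k/v/out tensors are no longer part of what the C extension's fwd and varlen_fwd return, and the commit pairs that with the FLASH_VER_MIN bump so only the new signature can be loaded. Had both versions needed to stay supported, one hedged alternative would be to unpack by arity; a sketch under that assumption, not what this commit (or xformers) does:

# Illustrative arity-based unpack (an assumption, not this commit's
# approach): tolerate both the pre-2.7 8-tuple and the 2.7+ 4-tuple.
from typing import Any, Sequence, Tuple


def unpack_fwd_ret(ret: Sequence[Any]) -> Tuple[Any, Any, Any, Any]:
    """Normalize _C_flashattention.fwd/varlen_fwd results to
    (out, softmax_lse, p, rng_state)."""
    if len(ret) == 8:
        # flash-attn <= 2.6.x:
        # (out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state)
        out, _, _, _, _, softmax_lse, p, rng_state = ret
    elif len(ret) == 4:
        # flash-attn >= 2.7.x: (out, softmax_lse, p, rng_state)
        out, softmax_lse, p, rng_state = ret
    else:
        raise RuntimeError(f"unexpected fwd return arity: {len(ret)}")
    return out, softmax_lse, p, rng_state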
