Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
HanGuo97 committed Nov 17, 2024
1 parent b7c4c20 commit 3f0c4bb
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tests/higgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,12 @@ def test_vector_dequantize() -> None:
dtype=dtype,
device=device)
grid = torch.randn(
num_codes,
vector_size,
(num_codes, vector_size),
dtype=dtype,
device=device)

outputs = vector_dequantize(
weight_higgs=weight_higgs.int(),
weight_higgs=weight_higgs,
scales_higgs=scales_higgs,
grid=grid,
num_bits=num_bits,
Expand All @@ -91,10 +90,9 @@ def test_vector_dequantize() -> None:
device=device)

outputs_higgs = vector_dequantize_higgs(
weight_higgs=weight_higgs,
weight_higgs=weight_higgs.int(),
scales_higgs=scales_higgs,
grid=grid)

if not (outputs == outputs_higgs.T).all():
raise ValueError

0 comments on commit 3f0c4bb

Please sign in to comment.