Skip to content

Commit

Permalink
Improve: Warnings for overflows
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Oct 7, 2024
1 parent 470ab82 commit 31a8b4b
Showing 1 changed file with 101 additions and 8 deletions.
109 changes: 101 additions & 8 deletions python/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def stats_fixture():
results["accurate_duration"] = []
results["baseline_duration"] = []
results["simsimd_duration"] = []
results["warnings"] = []
yield results

# Group the errors by (metric, ndim, dtype) to calculate the mean and std error.
Expand Down Expand Up @@ -203,6 +204,9 @@ def stats_fixture():
# Mean and the standard deviation for errors
baseline_errors = errors["relative_baseline_error"]
simsimd_errors = errors["relative_simsimd_error"]
#! On some platforms (like `cp312-musllinux_aarch64`) without casting via `float(x)`
#! the subsequent `:.2e` string formatting code will fail due to:
#! `TypeError: unsupported format string passed to numpy.ndarray.__format__`.
baseline_mean = float(sum(baseline_errors)) / n
simsimd_mean = float(sum(simsimd_errors)) / n
baseline_std = math.sqrt(sum((x - baseline_mean) ** 2 for x in baseline_errors) / n)
Expand Down Expand Up @@ -267,6 +271,16 @@ def stats_fixture():
]
print(tabulate.tabulate(final_results, headers=headers, tablefmt="pretty", showindex=True))

# Show the additional grouped warnings
warnings = results.get("warnings", [])
warnings = sorted(warnings)
warnings = [f"{name}: {message}" for name, message in warnings]
if len(warnings) != 0:
print("\nWarnings:")
unique_warnings, warning_counts = np.unique(warnings, return_counts=True)
for warning, count in zip(unique_warnings, warning_counts):
print(f"- {count}x times: {warning}")


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_makereport(item, call):
Expand Down Expand Up @@ -313,6 +327,24 @@ def collect_errors(
stats["simsimd_duration"].append(simsimd_duration)


def get_current_test():
    """Get the current test's file name, test name, and function name.

    Similar metadata can be obtained from the `request` fixture, but this
    solution uses the `PYTEST_CURRENT_TEST` environment variable, so it can be
    called from helpers that don't have access to the fixture.

    Returns:
        A `(test_file, test_name, function_name)` tuple, where `test_file` is
        the file name without the directory and the `.py` extension,
        `test_name` is the function name including the parameter IDs in
        square brackets, and `function_name` is the bare function name.

    Raises:
        RuntimeError: If `PYTEST_CURRENT_TEST` is not set or is empty.
    """
    current = os.environ.get("PYTEST_CURRENT_TEST")
    if not current:
        raise RuntimeError("`PYTEST_CURRENT_TEST` is not set; this helper must be called from a running test")

    # The variable looks like: "python/test.py::test_dense_i8[cosine-1536-24-50] (call)".
    # Strip only the trailing " (stage)" suffix, as the parameter IDs may contain spaces.
    node_id, separator, _ = current.rpartition(" (")
    if not separator:
        node_id = current

    # Parameter IDs may contain "::" or "/" themselves, so set them aside before splitting.
    location, bracket, parameters = node_id.partition("[")
    parts = location.split("::")
    test_file = parts[0].split("/")[-1].split(".py")[0]

    # The last component is the function, even for tests nested in classes,
    # like: "test.py::TestKernels::test_dense_i8".
    function_name = parts[-1]
    # The `test_name` may look like: "test_dense_i8[cosine-1536-24-50]"
    test_name = function_name + bracket + parameters
    return test_file, test_name, function_name


def collect_warnings(message: str, stats: dict):
    """Records a warning under the name of the currently running test function.

    The `(function_name, message)` pair is appended to `stats["warnings"]`.
    """
    function_name = get_current_test()[2]
    entry = (function_name, message)
    stats["warnings"].append(entry)


# For normalized distances we use the absolute tolerance, because the result is close to zero.
# For unnormalized ones (like squared Euclidean or Jaccard), we use the relative.
SIMSIMD_RTOL = 0.1
Expand Down Expand Up @@ -645,21 +677,22 @@ def test_dense_i8(ndim, metric, stats_fixture):

baseline_kernel, simd_kernel = name_to_kernels(metric)

# Fun fact: SciPy doesn't actually raise an `OverflowError` when overflow happens
# here, instead it raises `ValueError: math domain error` during the `sqrt` operation.
try:
expected_overflow = baseline_kernel(a, b)
except OverflowError:
expected_overflow = OverflowError()
except ValueError:
expected_overflow = ValueError()
accurate_dt, accurate = profile(baseline_kernel, a.astype(np.float64), b.astype(np.float64))
expected_dt, expected = profile(baseline_kernel, a.astype(np.int64), b.astype(np.int64))
result_dt, result = profile(simd_kernel, a, b)

assert int(result) == int(expected), f"Expected {expected}, but got {result} (overflow: {expected_overflow})"
collect_errors(metric, ndim, "int8", accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture)

#! Fun fact: SciPy doesn't actually raise an `OverflowError` when overflow happens
#! here, instead it raises `ValueError: math domain error` during the `sqrt` operation.
try:
expected_overflow = baseline_kernel(a, b)
if np.isinf(expected_overflow):
collect_warnings("Couldn't avoid overflow in SciPy", stats_fixture)
except Exception as e:
collect_warnings(f"Arbitrary error raised in SciPy: {e}", stats_fixture)


@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
Expand Down Expand Up @@ -726,6 +759,66 @@ def test_cosine_zero_vector(ndim, dtype):
assert np.all(result >= 0), f"Negative result for cosine distance"


@pytest.mark.skip()  # TODO: https://github.com/ashvardanian/SimSIMD/issues/206
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.parametrize("ndim", [11, 97, 1536])
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16"])
@pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"])
def test_overflow(ndim, dtype, metric, stats_fixture):
    """Tests if the floating-point kernels are capable of detecting overflow and yield the same ±inf result.

    The `stats_fixture` is requested, so that the disagreements with the baseline
    can be recorded as warnings instead of failing the test.
    """

    np.random.seed()
    a = np.random.randn(ndim)
    b = np.random.randn(ndim)

    # Replace scalar at random position with infinity
    a[np.random.randint(ndim)] = np.inf
    a = a.astype(dtype)
    b = b.astype(dtype)

    baseline_kernel, simd_kernel = name_to_kernels(metric)
    result = simd_kernel(a, b)
    assert np.isinf(result), f"Expected ±inf, but got {result}"

    #! In the Euclidean (L2) distance, SciPy raises a `ValueError` from the underlying
    #! NumPy function: `ValueError: array must not contain infs or NaNs`.
    #! The baseline is only informative here, so any of its failures becomes a warning.
    try:
        expected_overflow = baseline_kernel(a, b)
        if not np.isinf(expected_overflow):
            collect_warnings("Overflow not detected in SciPy", stats_fixture)
    except Exception as e:
        collect_warnings(f"Arbitrary error raised in SciPy: {e}", stats_fixture)


@pytest.mark.skip()  # TODO: https://github.com/ashvardanian/SimSIMD/issues/206
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
@pytest.mark.parametrize("ndim", [131072, 262144])
@pytest.mark.parametrize("metric", ["inner", "euclidean", "sqeuclidean", "cosine"])
def test_overflow_i8(ndim, metric, stats_fixture):
    """Tests if the integral kernels are capable of detecting overflow and yield the same ±inf result,
    as with 2^16 elements accumulating "u32(u16(u8)*u16(u8))+u32" products should overflow and the
    same is true for 2^17 elements with "i32(i15(i8))*i32(i15(i8))" products.

    The `stats_fixture` is requested, so that the disagreements with the baseline
    can be recorded as warnings instead of failing the test.
    """

    np.random.seed()
    # Both vectors are filled with the `int8` value of the largest magnitude.
    a = np.full(ndim, fill_value=-128, dtype=np.int8)
    b = np.full(ndim, fill_value=-128, dtype=np.int8)

    baseline_kernel, simd_kernel = name_to_kernels(metric)
    result = simd_kernel(a, b)
    assert np.isinf(result), f"Expected ±inf, but got {result}"

    #! The baseline is only called inside the `try` block, so that its errors
    #! are reported as warnings and can't abort the test.
    try:
        expected_overflow = baseline_kernel(a, b)
        if not np.isinf(expected_overflow):
            collect_warnings("Overflow not detected in SciPy", stats_fixture)
    except Exception as e:
        collect_warnings(f"Arbitrary error raised in SciPy: {e}", stats_fixture)


@pytest.mark.skipif(is_running_under_qemu(), reason="Complex math in QEMU fails")
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
@pytest.mark.repeat(50)
Expand Down

0 comments on commit 31a8b4b

Please sign in to comment.