Improve: Faster FMA on Haswell
ashvardanian committed Oct 27, 2024
1 parent d89a81c commit 40a5c38
Showing 2 changed files with 185 additions and 4 deletions.
185 changes: 181 additions & 4 deletions include/simsimd/elementwise.h
@@ -497,22 +497,199 @@ SIMSIMD_PUBLIC void simsimd_fma_bf16_haswell( /
SIMSIMD_PUBLIC void simsimd_wsum_i8_haswell( //
    simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_size_t n, //
    simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) {

    simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
    simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
    __m256 alpha_vec = _mm256_set1_ps(alpha_f32);
    __m256 beta_vec = _mm256_set1_ps(beta_f32);

    // The main loop:
    simsimd_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        //? Handling loads and stores with SIMD is tricky. Not because of the upcasting, but because of
        //? the downcasting at the end of the loop. In AVX2 it's a drag! Keep it for another day.
        __m256 a_vec, b_vec;
        a_vec[0] = a[i + 0], a_vec[1] = a[i + 1], a_vec[2] = a[i + 2], a_vec[3] = a[i + 3], //
            a_vec[4] = a[i + 4], a_vec[5] = a[i + 5], a_vec[6] = a[i + 6], a_vec[7] = a[i + 7];
        b_vec[0] = b[i + 0], b_vec[1] = b[i + 1], b_vec[2] = b[i + 2], b_vec[3] = b[i + 3], //
            b_vec[4] = b[i + 4], b_vec[5] = b[i + 5], b_vec[6] = b[i + 6], b_vec[7] = b[i + 7];
        // The normal part.
        __m256 a_scaled = _mm256_mul_ps(a_vec, alpha_vec);
        __m256 b_scaled = _mm256_mul_ps(b_vec, beta_vec);
        __m256 sum_vec = _mm256_add_ps(a_scaled, b_scaled);
        // Instead of serial calls to the expensive `SIMSIMD_F32_TO_I8`, convert and clip with SIMD.
        __m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
        sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(-128));
        sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(127));
        // Export into a serial buffer.
        int sum_i32s[8];
        _mm256_storeu_si256((__m256i *)sum_i32s, sum_i32_vec);
        result[i + 0] = (simsimd_i8_t)sum_i32s[0];
        result[i + 1] = (simsimd_i8_t)sum_i32s[1];
        result[i + 2] = (simsimd_i8_t)sum_i32s[2];
        result[i + 3] = (simsimd_i8_t)sum_i32s[3];
        result[i + 4] = (simsimd_i8_t)sum_i32s[4];
        result[i + 5] = (simsimd_i8_t)sum_i32s[5];
        result[i + 6] = (simsimd_i8_t)sum_i32s[6];
        result[i + 7] = (simsimd_i8_t)sum_i32s[7];
    }

    // The tail:
    for (; i < n; ++i) {
        simsimd_f32_t ai = a[i], bi = b[i];
        simsimd_f32_t sum = alpha_f32 * ai + beta_f32 * bi;
        SIMSIMD_F32_TO_I8(sum, result + i);
    }
}
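// A hypothetical sketch, not part of this commit: the "drag" mentioned above is the final
// i32 -> i8 downcast, which could stay in AVX2 registers roughly like the helper below.
// The helper name and `static` linkage are illustrative assumptions, not library APIs.
// The loads are the easy half, e.g. `_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_loadu_si64(a + i)))`.
static void _simsimd_sketch_i32x8_to_i8_haswell(__m256i sum_i32_vec, simsimd_i8_t *result) {
    // `_mm256_packs_epi32` saturates within each 128-bit lane, so the useful halves land in
    // qwords 0 and 2; a cross-lane shuffle brings them together before the 16 -> 8 bit pack.
    __m256i sum_i16_vec = _mm256_packs_epi32(sum_i32_vec, _mm256_setzero_si256());
    sum_i16_vec = _mm256_permute4x64_epi64(sum_i16_vec, 0xD8);
    __m128i sum_i8_vec = _mm_packs_epi16(_mm256_castsi256_si128(sum_i16_vec), _mm_setzero_si128());
    _mm_storeu_si64(result, sum_i8_vec); // writes exactly 8 bytes
}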

SIMSIMD_PUBLIC void simsimd_wsum_u8_haswell( //
    simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_size_t n, //
    simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) {

    simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
    simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
    __m256 alpha_vec = _mm256_set1_ps(alpha_f32);
    __m256 beta_vec = _mm256_set1_ps(beta_f32);

    // The main loop:
    simsimd_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        //? Handling loads and stores with SIMD is tricky. Not because of the upcasting, but because of
        //? the downcasting at the end of the loop. In AVX2 it's a drag! Keep it for another day.
        __m256 a_vec, b_vec;
        a_vec[0] = a[i + 0], a_vec[1] = a[i + 1], a_vec[2] = a[i + 2], a_vec[3] = a[i + 3], //
            a_vec[4] = a[i + 4], a_vec[5] = a[i + 5], a_vec[6] = a[i + 6], a_vec[7] = a[i + 7];
        b_vec[0] = b[i + 0], b_vec[1] = b[i + 1], b_vec[2] = b[i + 2], b_vec[3] = b[i + 3], //
            b_vec[4] = b[i + 4], b_vec[5] = b[i + 5], b_vec[6] = b[i + 6], b_vec[7] = b[i + 7];
        // The normal part.
        __m256 a_scaled = _mm256_mul_ps(a_vec, alpha_vec);
        __m256 b_scaled = _mm256_mul_ps(b_vec, beta_vec);
        __m256 sum_vec = _mm256_add_ps(a_scaled, b_scaled);
        // Instead of serial calls to the expensive `SIMSIMD_F32_TO_U8`, convert and clip with SIMD.
        __m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
        sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(0));
        sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(255));
        // Export into a serial buffer.
        int sum_i32s[8];
        _mm256_storeu_si256((__m256i *)sum_i32s, sum_i32_vec);
        result[i + 0] = (simsimd_u8_t)sum_i32s[0];
        result[i + 1] = (simsimd_u8_t)sum_i32s[1];
        result[i + 2] = (simsimd_u8_t)sum_i32s[2];
        result[i + 3] = (simsimd_u8_t)sum_i32s[3];
        result[i + 4] = (simsimd_u8_t)sum_i32s[4];
        result[i + 5] = (simsimd_u8_t)sum_i32s[5];
        result[i + 6] = (simsimd_u8_t)sum_i32s[6];
        result[i + 7] = (simsimd_u8_t)sum_i32s[7];
    }

    // The tail:
    for (; i < n; ++i) {
        simsimd_f32_t ai = a[i], bi = b[i];
        simsimd_f32_t sum = alpha_f32 * ai + beta_f32 * bi;
        SIMSIMD_F32_TO_U8(sum, result + i);
    }
}

SIMSIMD_PUBLIC void simsimd_fma_i8_haswell( //
    simsimd_i8_t const *a, simsimd_i8_t const *b, simsimd_i8_t const *c, simsimd_size_t n, //
    simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_i8_t *result) {

    simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
    simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
    __m256 alpha_vec = _mm256_set1_ps(alpha_f32);
    __m256 beta_vec = _mm256_set1_ps(beta_f32);

    // The main loop:
    simsimd_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        //? Handling loads and stores with SIMD is tricky. Not because of the upcasting, but because of
        //? the downcasting at the end of the loop. In AVX2 it's a drag! Keep it for another day.
        __m256 a_vec, b_vec, c_vec;
        a_vec[0] = a[i + 0], a_vec[1] = a[i + 1], a_vec[2] = a[i + 2], a_vec[3] = a[i + 3], //
            a_vec[4] = a[i + 4], a_vec[5] = a[i + 5], a_vec[6] = a[i + 6], a_vec[7] = a[i + 7];
        b_vec[0] = b[i + 0], b_vec[1] = b[i + 1], b_vec[2] = b[i + 2], b_vec[3] = b[i + 3], //
            b_vec[4] = b[i + 4], b_vec[5] = b[i + 5], b_vec[6] = b[i + 6], b_vec[7] = b[i + 7];
        c_vec[0] = c[i + 0], c_vec[1] = c[i + 1], c_vec[2] = c[i + 2], c_vec[3] = c[i + 3], //
            c_vec[4] = c[i + 4], c_vec[5] = c[i + 5], c_vec[6] = c[i + 6], c_vec[7] = c[i + 7];
        // The normal part.
        __m256 ab_vec = _mm256_mul_ps(a_vec, b_vec);
        __m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec);
        __m256 c_scaled_vec = _mm256_mul_ps(c_vec, beta_vec);
        __m256 sum_vec = _mm256_add_ps(ab_scaled_vec, c_scaled_vec);
        // Instead of serial calls to the expensive `SIMSIMD_F32_TO_I8`, convert and clip with SIMD.
        __m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
        sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(-128));
        sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(127));
        // Export into a serial buffer.
        int sum_i32s[8];
        _mm256_storeu_si256((__m256i *)sum_i32s, sum_i32_vec);
        result[i + 0] = (simsimd_i8_t)sum_i32s[0];
        result[i + 1] = (simsimd_i8_t)sum_i32s[1];
        result[i + 2] = (simsimd_i8_t)sum_i32s[2];
        result[i + 3] = (simsimd_i8_t)sum_i32s[3];
        result[i + 4] = (simsimd_i8_t)sum_i32s[4];
        result[i + 5] = (simsimd_i8_t)sum_i32s[5];
        result[i + 6] = (simsimd_i8_t)sum_i32s[6];
        result[i + 7] = (simsimd_i8_t)sum_i32s[7];
    }

    // The tail:
    for (; i < n; ++i) {
        simsimd_f32_t ai = a[i], bi = b[i], ci = c[i];
        simsimd_f32_t sum = alpha_f32 * ai * bi + beta_f32 * ci;
        SIMSIMD_F32_TO_I8(sum, result + i);
    }
}

SIMSIMD_PUBLIC void simsimd_fma_u8_haswell( //
    simsimd_u8_t const *a, simsimd_u8_t const *b, simsimd_u8_t const *c, simsimd_size_t n, //
    simsimd_distance_t alpha, simsimd_distance_t beta, simsimd_u8_t *result) {

    simsimd_f32_t alpha_f32 = (simsimd_f32_t)alpha;
    simsimd_f32_t beta_f32 = (simsimd_f32_t)beta;
    __m256 alpha_vec = _mm256_set1_ps(alpha_f32);
    __m256 beta_vec = _mm256_set1_ps(beta_f32);

    // The main loop:
    simsimd_size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        //? Handling loads and stores with SIMD is tricky. Not because of the upcasting, but because of
        //? the downcasting at the end of the loop. In AVX2 it's a drag! Keep it for another day.
        __m256 a_vec, b_vec, c_vec;
        a_vec[0] = a[i + 0], a_vec[1] = a[i + 1], a_vec[2] = a[i + 2], a_vec[3] = a[i + 3], //
            a_vec[4] = a[i + 4], a_vec[5] = a[i + 5], a_vec[6] = a[i + 6], a_vec[7] = a[i + 7];
        b_vec[0] = b[i + 0], b_vec[1] = b[i + 1], b_vec[2] = b[i + 2], b_vec[3] = b[i + 3], //
            b_vec[4] = b[i + 4], b_vec[5] = b[i + 5], b_vec[6] = b[i + 6], b_vec[7] = b[i + 7];
        c_vec[0] = c[i + 0], c_vec[1] = c[i + 1], c_vec[2] = c[i + 2], c_vec[3] = c[i + 3], //
            c_vec[4] = c[i + 4], c_vec[5] = c[i + 5], c_vec[6] = c[i + 6], c_vec[7] = c[i + 7];
        // The normal part.
        __m256 ab_vec = _mm256_mul_ps(a_vec, b_vec);
        __m256 ab_scaled_vec = _mm256_mul_ps(ab_vec, alpha_vec);
        __m256 c_scaled_vec = _mm256_mul_ps(c_vec, beta_vec);
        __m256 sum_vec = _mm256_add_ps(ab_scaled_vec, c_scaled_vec);
        // Instead of serial calls to the expensive `SIMSIMD_F32_TO_U8`, convert and clip with SIMD.
        __m256i sum_i32_vec = _mm256_cvtps_epi32(sum_vec);
        sum_i32_vec = _mm256_max_epi32(sum_i32_vec, _mm256_set1_epi32(0));
        sum_i32_vec = _mm256_min_epi32(sum_i32_vec, _mm256_set1_epi32(255));
        // Export into a serial buffer.
        int sum_i32s[8];
        _mm256_storeu_si256((__m256i *)sum_i32s, sum_i32_vec);
        result[i + 0] = (simsimd_u8_t)sum_i32s[0];
        result[i + 1] = (simsimd_u8_t)sum_i32s[1];
        result[i + 2] = (simsimd_u8_t)sum_i32s[2];
        result[i + 3] = (simsimd_u8_t)sum_i32s[3];
        result[i + 4] = (simsimd_u8_t)sum_i32s[4];
        result[i + 5] = (simsimd_u8_t)sum_i32s[5];
        result[i + 6] = (simsimd_u8_t)sum_i32s[6];
        result[i + 7] = (simsimd_u8_t)sum_i32s[7];
    }

    // The tail:
    for (; i < n; ++i) {
        simsimd_f32_t ai = a[i], bi = b[i], ci = c[i];
        simsimd_f32_t sum = alpha_f32 * ai * bi + beta_f32 * ci;
        SIMSIMD_F32_TO_U8(sum, result + i);
    }
}
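// A hypothetical sketch, not part of this commit, of the unsigned counterpart to the note above:
// the sums are pre-clipped to [0, 255], so a signed 32 -> 16 bit pack followed by an
// unsigned-saturating 16 -> 8 bit pack is enough. The helper name and `static` linkage are
// illustrative assumptions only.
static void _simsimd_sketch_i32x8_to_u8_haswell(__m256i sum_i32_vec, simsimd_u8_t *result) {
    __m256i sum_i16_vec = _mm256_packs_epi32(sum_i32_vec, _mm256_setzero_si256());
    sum_i16_vec = _mm256_permute4x64_epi64(sum_i16_vec, 0xD8); // gather both lanes' low halves
    __m128i sum_u8_vec = _mm_packus_epi16(_mm256_castsi256_si128(sum_i16_vec), _mm_setzero_si128());
    _mm_storeu_si64(result, sum_u8_vec); // writes exactly 8 bytes
}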

#pragma clang attribute pop
4 changes: 4 additions & 0 deletions scripts/bench.cxx
@@ -905,6 +905,10 @@ int main(int argc, char **argv) {
    fma_<f16_k>("wsum_f16_haswell", simsimd_wsum_f16_haswell, simsimd_wsum_f16_accurate, simsimd_l2_f16_accurate);
    fma_<bf16_k>("fma_bf16_haswell", simsimd_fma_bf16_haswell, simsimd_fma_bf16_accurate, simsimd_l2_bf16_accurate);
    fma_<bf16_k>("wsum_bf16_haswell", simsimd_wsum_bf16_haswell, simsimd_wsum_bf16_accurate, simsimd_l2_bf16_accurate);
    fma_<i8_k>("fma_i8_haswell", simsimd_fma_i8_haswell, simsimd_fma_i8_accurate, simsimd_l2_i8_serial);
    fma_<i8_k>("wsum_i8_haswell", simsimd_wsum_i8_haswell, simsimd_wsum_i8_accurate, simsimd_l2_i8_serial);
    fma_<u8_k>("fma_u8_haswell", simsimd_fma_u8_haswell, simsimd_fma_u8_accurate, simsimd_l2_u8_serial);
    fma_<u8_k>("wsum_u8_haswell", simsimd_wsum_u8_haswell, simsimd_wsum_u8_accurate, simsimd_l2_u8_serial);

#endif

