Skip to content

Commit

Permalink
+add AVX-512BW optimizations of class ResizerBf16Bilinear (part 8: large channels case).
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Jan 14, 2025
1 parent 5328b2f commit ff7d4dd
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 40 deletions.
66 changes: 28 additions & 38 deletions src/Simd/SimdAvx512bwResizerBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,14 +921,10 @@ namespace Simd
{
size_t cn = _param.channels, cnD = AlignLo(cn, DF), cnF = AlignLo(cn, F), cnH = AlignLo(cn, F / 2);
__mmask16 cnMF = TailMask16(cn - cnF);
__mmask8 cnMH = TailMask8(cn - cnH);
__m512 _1 = _mm512_set1_ps(1.0f);
if (_rowBuf)
{
if (cn > 4)
{
Avx2::ResizerBf16Bilinear::Run(src, srcStride, dst, dstStride);
return;
}
size_t rs = _param.dstW * cn, rsQ = AlignLo(rs, Sse41::F), rsH = AlignLo(rs, Avx2::F), rsF = AlignLo(rs, F), rsD = AlignLo(rs, DF);
size_t rs3 = AlignLoAny(rs - 1, 3), rs6 = AlignLoAny(rs - 1, 6), rs12 = AlignLoAny(rs - 1, 12);
__mmask16 rsMF = TailMask16(rs - rsF);
Expand Down Expand Up @@ -975,7 +971,6 @@ namespace Simd
__m512 fx1 = _mm512_maskz_loadu_ps(rsMF, _ax.data + dx);
__m512 fx0 = _mm512_sub_ps(_1, fx1);
_mm512_mask_storeu_ps(pb + dx, rsMF, _mm512_fmadd_ps(fx0, s0, _mm512_mul_ps(fx1, s1)));
dx = rs;
}
}
else if (cn == 2)
Expand All @@ -997,7 +992,6 @@ namespace Simd
__m512 fx1 = _mm512_loadu_ps(_ax.data + dx);
__m512 fx0 = _mm512_sub_ps(_1, fx1);
_mm512_storeu_ps(pb + dx, _mm512_fmadd_ps(fx0, s0, _mm512_mul_ps(fx1, s1)));
dx = rs;
}
}
else if (cn == 3 && rs >= 3)
Expand Down Expand Up @@ -1027,6 +1021,12 @@ namespace Simd
__m128 fx0 = _mm_sub_ps(_mm512_castps512_ps128(_1), fx1);
_mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps + _ix[dx], cn, fx0, fx1));
}
for (; dx < rs; dx++)
{
int32_t sx = _ix[dx];
float fx = _ax[dx];
pb[dx] = Base::BFloat16ToFloat32(ps[sx]) * (1.0f - fx) + Base::BFloat16ToFloat32(ps[sx + cn]) * fx;
}
}
else if (cn == 4)
{
Expand All @@ -1045,37 +1045,28 @@ namespace Simd
_mm_storeu_ps(pb + dx, _mm_add_ps(_mm_mul_ps(fx0, Sse41::BFloat16ToFloat32<0>(_src)), _mm_mul_ps(fx1, Sse41::BFloat16ToFloat32<1>(_src))));
}
}
// if (cn >= 8)
// {
// for (; dx < rs;)
// {
// const uint16_t* ps0 = ps + _ix[dx];
// __m256 fx1 = _mm256_set1_ps(_ax[dx]);
// __m256 fx0 = _mm256_sub_ps(_1, fx1);
// for (size_t end = dx + cnF; dx < end; dx += F, ps0 += F)
// _mm256_storeu_ps(pb + dx, BilinearRowSumBf16(ps0, cn, fx0, fx1));
// if (cnTF)
// _mm256_storeu_ps(pb + dx + cnLF, BilinearRowSumBf16(ps0 + cnLF, cn, fx0, fx1)), dx += cnTF;
// }
// }
// else if (cn > 4)
// {
// for (; dx < rs;)
// {
// const uint16_t* ps0 = ps + _ix[dx];
// __m128 fx1 = _mm_set1_ps(_ax[dx]);
// __m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1);
// for (size_t end = dx + cnH; dx < end; dx += Sse41::F, ps0 += Sse41::F)
// _mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps0, cn, fx0, fx1));
// if (cnTH)
// _mm_storeu_ps(pb + dx + cnLH, BilinearRowSumBf16(ps0 + cnLH, cn, fx0, fx1)), dx += cnTH;
// }
// }
for (; dx < rs; dx++)
else if (cn <= 8)
{
int32_t sx = _ix[dx];
float fx = _ax[dx];
pb[dx] = Base::BFloat16ToFloat32(ps[sx]) * (1.0f - fx) + Base::BFloat16ToFloat32(ps[sx + cn]) * fx;
for (; dx < rs; dx += cn)
{
__m256 fx1 = _mm256_set1_ps(_ax[dx]);
__m256 fx0 = _mm256_sub_ps(_mm512_castps512_ps256(_1), fx1);
_mm256_mask_storeu_ps(pb + dx, cnMH, BilinearRowSumBf16(ps + _ix[dx], cnMH, cn, fx0, fx1));
}
}
else
{
size_t cnT = cn - cnF;
for (; dx < rs;)
{
const uint16_t* ps0 = ps + _ix[dx];
__m512 fx1 = _mm512_set1_ps(_ax[dx]);
__m512 fx0 = _mm512_sub_ps(_1, fx1);
for (size_t eF = dx + cnF; dx < eF; dx += F, ps0 += F)
_mm512_storeu_ps(pb + dx, BilinearRowSumBf16(ps0, -1, cn, fx0, fx1));
if (cnMF)
_mm512_mask_storeu_ps(pb + dx, cnMF, BilinearRowSumBf16(ps0, cnMF, cn, fx0, fx1)), dx += cnT;
}
}
}

Expand Down Expand Up @@ -1150,7 +1141,6 @@ namespace Simd
}
else if (cn <= HF)
{
__mmask8 cnMH = TailMask16(cn);
for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
{
__m256 fy1 = _mm256_set1_ps(_ay[dy]);
Expand Down
2 changes: 0 additions & 2 deletions src/Test/TestResize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,6 @@ namespace Test
{
bool result = true;

result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 4, f1, f2);

#if 0
#if defined(SIMD_X64_ENABLE)
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelFloat, 64, f1, f2);
Expand Down

0 comments on commit ff7d4dd

Please sign in to comment.