Skip to content

Commit

Permalink
*improve SSE4.1 optimizations of class ResizerBf16Bilinear.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Dec 27, 2024
1 parent 6f7b1e4 commit 1f7f896
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/Simd/SimdAvx2ResizerBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ namespace Simd
_mm_storeu_ps(pb + dx, _mm_add_ps(_mm_mul_ps(fx0, s0), _mm_mul_ps(fx1, s1)));
}
}
else if (Avx2::SlowGather)
else if (!Avx2::SlowGather)
{
__m256 _1 = _mm256_set1_ps(1.0f);
__m256i _cn = _mm256_set1_epi32((int)cn);
Expand Down
32 changes: 18 additions & 14 deletions src/Simd/SimdBaseResizerBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ namespace Simd

//-------------------------------------------------------------------------------------------------

static void EstimateIndexAlpha(const ResParam& param, size_t srcSize, size_t dstSize, size_t channels, int32_t* indices, float* alphas)
static void EstimateIndexAlpha(const ResParam& param, size_t srcSize, size_t dstSize, size_t channels, size_t iaChannels, int32_t* indices, float* alphas)
{
if (param.method == SimdResizeMethodBilinear)
{
Expand All @@ -270,9 +270,9 @@ namespace Simd
index = srcSize - 2;
alpha = 1;
}
for (size_t c = 0; c < channels; c++)
for (size_t c = 0; c < iaChannels; c++)
{
size_t offset = i * channels + c;
size_t offset = i * iaChannels + c;
indices[offset] = (int32_t)(channels * index + c);
alphas[offset] = alpha;
}
Expand All @@ -291,9 +291,9 @@ namespace Simd
index = srcSize - 2;
alpha = 1;
}
for (size_t c = 0; c < channels; c++)
for (size_t c = 0; c < iaChannels; c++)
{
size_t offset = i * channels + c;
size_t offset = i * iaChannels + c;
indices[offset] = (int32_t)(channels * index + c);
alphas[offset] = alpha;
}
Expand All @@ -317,9 +317,9 @@ namespace Simd
index = srcSize - 2;
alpha = 1;
}
for (size_t c = 0; c < channels; c++)
for (size_t c = 0; c < iaChannels; c++)
{
size_t offset = i * channels + c;
size_t offset = i * iaChannels + c;
indices[offset] = (int32_t)(channels * index + c);
alphas[offset] = alpha;
}
Expand All @@ -336,11 +336,11 @@ namespace Simd
{
_ay.Resize(_param.dstH, false, _param.align);
_iy.Resize(_param.dstH, false, _param.align);
EstimateIndexAlpha(_param, _param.srcH, _param.dstH, 1, _iy.data, _ay.data);
EstimateIndexAlpha(_param, _param.srcH, _param.dstH, 1, 1, _iy.data, _ay.data);
size_t rs = _param.dstW * _param.channels;
_ax.Resize(rs, false, _param.align);
_ix.Resize(rs, false, _param.align);
EstimateIndexAlpha(_param, _param.srcW, _param.dstW, _param.channels, _ix.data, _ax.data);
EstimateIndexAlpha(_param, _param.srcW, _param.dstW, _param.channels, _param.channels, _ix.data, _ax.data);
_bx[0].Resize(rs, false, _param.align);
_bx[1].Resize(rs, false, _param.align);
}
Expand Down Expand Up @@ -395,15 +395,19 @@ namespace Simd
ResizerBf16Bilinear::ResizerBf16Bilinear(const ResParam& param)
: Resizer(param)
{
_rowBuf = !(_param.align >= 16 && (_param.channels >= _param.align / 4 || _param.channels == 64));
_ay.Resize(_param.dstH, false, _param.align);
_iy.Resize(_param.dstH, false, _param.align);
EstimateIndexAlpha(_param, _param.srcH, _param.dstH, 1, _iy.data, _ay.data);
size_t rs = _param.dstW * _param.channels;
EstimateIndexAlpha(_param, _param.srcH, _param.dstH, 1, 1, _iy.data, _ay.data);
size_t rs = _param.dstW * (_rowBuf ? _param.channels : 1);
_ax.Resize(rs, false, _param.align);
_ix.Resize(rs, false, _param.align);
EstimateIndexAlpha(_param, _param.srcW, _param.dstW, _param.channels, _ix.data, _ax.data);
_bx[0].Resize(rs, false, _param.align);
_bx[1].Resize(rs, false, _param.align);
EstimateIndexAlpha(_param, _param.srcW, _param.dstW, _param.channels, _rowBuf ? _param.channels : 1, _ix.data, _ax.data);
if (_rowBuf)
{
_bx[0].Resize(rs, false, _param.align);
_bx[1].Resize(rs, false, _param.align);
}
}

void ResizerBf16Bilinear::Run(const uint8_t* src, size_t srcStride, uint8_t* dst, size_t dstStride)
Expand Down
1 change: 1 addition & 0 deletions src/Simd/SimdResizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ namespace Simd
class ResizerBf16Bilinear : public Resizer
{
protected:
bool _rowBuf;
Array32i _ix, _iy;
Array32f _ax, _ay, _bx[2];

Expand Down
137 changes: 89 additions & 48 deletions src/Simd/SimdSse41ResizerBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,68 +681,109 @@ namespace Simd

void ResizerBf16Bilinear::Run(const uint16_t* src, size_t srcStride, uint16_t* dst, size_t dstStride)
{
size_t cn = _param.channels, cnF = AlignLo(cn, F);
size_t rs = _param.dstW * cn;
float* pbx[2] = { _bx[0].data, _bx[1].data };
int32_t prev = -2;
size_t rsh = AlignLo(rs, Sse41::F);
size_t cn = _param.channels, cnF = AlignLo(cn, F), cnT = cn - cnF;
__m128 _1 = _mm_set1_ps(1.0f);
for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
if (_rowBuf)
{
float fy1 = _ay[dy];
float fy0 = 1.0f - fy1;
int32_t sy = _iy[dy];
int32_t k = 0;

if (sy == prev)
k = 2;
else if (sy == prev + 1)
size_t rs = _param.dstW * cn, rsF = AlignLo(rs, Sse41::F);
float* pbx[2] = { _bx[0].data, _bx[1].data };
int32_t prev = -2;
for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
{
Swap(pbx[0], pbx[1]);
k = 1;
}
float fy1 = _ay[dy];
float fy0 = 1.0f - fy1;
int32_t sy = _iy[dy];
int32_t k = 0;

if (sy == prev)
k = 2;
else if (sy == prev + 1)
{
Swap(pbx[0], pbx[1]);
k = 1;
}

prev = sy;
prev = sy;

for (; k < 2; k++)
{
float* pb = pbx[k];
const uint16_t* ps = src + (sy + k) * srcStride;
size_t dx = 0;
if (cn == cnF)
for (; k < 2; k++)
{
for (; dx < rsh; dx += Sse41::F)
float* pb = pbx[k];
const uint16_t* ps = src + (sy + k) * srcStride;
size_t dx = 0;
if (cn == cnF)
{
const uint16_t* ps0 = ps + _ix[dx];
__m128 s0 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)ps0)));
__m128 s1 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(ps0 + cn))));
__m128 fx1 = _mm_loadu_ps(_ax.data + dx);
__m128 fx0 = _mm_sub_ps(_1, fx1);
__m128 m0 = _mm_mul_ps(fx0, s0);
__m128 m1 = _mm_mul_ps(fx1, s1);
_mm_store_ps(pb + dx, _mm_add_ps(m0, m1));
for (; dx < rsF; dx += Sse41::F)
{
const uint16_t* ps0 = ps + _ix[dx];
__m128 s0 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)ps0)));
__m128 s1 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(ps0 + cn))));
__m128 fx1 = _mm_loadu_ps(_ax.data + dx);
__m128 fx0 = _mm_sub_ps(_1, fx1);
__m128 m0 = _mm_mul_ps(fx0, s0);
__m128 m1 = _mm_mul_ps(fx1, s1);
_mm_store_ps(pb + dx, _mm_add_ps(m0, m1));
}
}
for (; dx < rs; dx++)
{
int32_t sx = _ix[dx];
float fx = _ax[dx];
pb[dx] = Base::BFloat16ToFloat32(ps[sx]) * (1.0f - fx) + Base::BFloat16ToFloat32(ps[sx + cn]) * fx;
}
}
for (; dx < rs; dx++)

size_t dx = 0;
__m128 _fy0 = _mm_set1_ps(fy0);
__m128 _fy1 = _mm_set1_ps(fy1);
for (; dx < rsF; dx += Sse41::F)
{
int32_t sx = _ix[dx];
float fx = _ax[dx];
pb[dx] = Base::BFloat16ToFloat32(ps[sx]) * (1.0f - fx) + Base::BFloat16ToFloat32(ps[sx + cn]) * fx;
__m128 m0 = _mm_mul_ps(_mm_load_ps(pbx[0] + dx), _fy0);
__m128 m1 = _mm_mul_ps(_mm_load_ps(pbx[1] + dx), _fy1);
__m128i d0 = Float32ToBFloat16(_mm_add_ps(m0, m1));
_mm_storel_epi64((__m128i*)(dst + dx), _mm_packus_epi32(d0, K_ZERO));
}
for (; dx < rs; dx++)
dst[dx] = Base::Float32ToBFloat16(pbx[0][dx] * fy0 + pbx[1][dx] * fy1);
}

size_t dx = 0;
__m128 _fy0 = _mm_set1_ps(fy0);
__m128 _fy1 = _mm_set1_ps(fy1);
for (; dx < rsh; dx += Sse41::F)
}
else
{
for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
{
__m128 m0 = _mm_mul_ps(_mm_load_ps(pbx[0] + dx), _fy0);
__m128 m1 = _mm_mul_ps(_mm_load_ps(pbx[1] + dx), _fy1);
__m128i d0 = Float32ToBFloat16(_mm_add_ps(m0, m1));
_mm_storel_epi64((__m128i*)(dst + dx), _mm_packus_epi32(d0, K_ZERO));
__m128 fy1 = _mm_set1_ps(_ay[dy]);
__m128 fy0 = _mm_sub_ps(_1, fy1);
const uint16_t* src0 = src + _iy[dy] * srcStride, * src1 = src0 + srcStride;
for (size_t dx = 0; dx < _param.dstW; dx++)
{
size_t os = _ix[dx], end = os + cnF, od = dx * cn;
__m128 fx1 = _mm_set1_ps(_ax[dx]);
__m128 fx0 = _mm_sub_ps(_1, fx1);
for (; os < end; os += F, od += F)
{
__m128 s00 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src0 + os))));
__m128 s01 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src0 + os + cn))));
__m128 s10 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src1 + os))));
__m128 s11 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src1 + os + cn))));
__m128 r0 = _mm_add_ps(_mm_mul_ps(fx0, s00), _mm_mul_ps(fx1, s01));
__m128 r1 = _mm_add_ps(_mm_mul_ps(fx0, s10), _mm_mul_ps(fx1, s11));
__m128i d0 = Float32ToBFloat16(_mm_add_ps(_mm_mul_ps(r0, fy0), _mm_mul_ps(r1, fy1)));
_mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, K_ZERO));
}
if (cnT)
{
os += cnT - F;
od += cnT - F;
__m128 s00 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src0 + os))));
__m128 s01 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src0 + os + cn))));
__m128 s10 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src1 + os))));
__m128 s11 = BFloat16ToFloat32(UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src1 + os + cn))));
__m128 r0 = _mm_add_ps(_mm_mul_ps(fx0, s00), _mm_mul_ps(fx1, s01));
__m128 r1 = _mm_add_ps(_mm_mul_ps(fx0, s10), _mm_mul_ps(fx1, s11));
__m128i d0 = Float32ToBFloat16(_mm_add_ps(_mm_mul_ps(r0, fy0), _mm_mul_ps(r1, fy1)));
_mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, K_ZERO));
}
}
}
for (; dx < rs; dx++)
dst[dx] = Base::Float32ToBFloat16(pbx[0][dx] * fy0 + pbx[1][dx] * fy1);
}
}
}
Expand Down
11 changes: 9 additions & 2 deletions src/Test/TestResize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,17 @@ namespace Test
{
bool result = true;

result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelFloat, 64, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelFloat, 16, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 64, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 16, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodNearest, SimdResizeChannelFloat, 16, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodNearest, SimdResizeChannelBf16, 16, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelFloat, 10, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 10, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelFloat, 3, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 4, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 3, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 2, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 1, f1, f2);

return result;

Expand Down

0 comments on commit 1f7f896

Please sign in to comment.