Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to RVV Concat/Combine ops #2420

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 79 additions & 45 deletions hwy/ops/rvv-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3228,67 +3228,101 @@ Get(D d, VFromD<D> v) {
}
}

#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMUL) \
NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) { \
return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL( \
dest, kIndex, v); /* no AVL */ \
// Generates NAME<kIndex>(dest, v, half_N): returns `dest` with the half-LMUL
// vector `v` written into one of its halves, where the size of a half is the
// caller-supplied `half_N`. `v` is first passed through Ext(d, v), with `d`
// being the tag of `dest`. For kIndex == 0 the `_tu` form of the
// `__riscv_v<OP>_v_v_*` intrinsic is called with `half_N` as its last
// argument; for kIndex == 1, SlideUp is called with `half_N` as the amount.
#define HWY_RVV_PARTIAL_VEC_SET_HALF(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \
                                     LMULH, SHIFT, MLEN, NAME, OP)             \
  template <size_t kIndex>                                                     \
  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                           \
      NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v,     \
           size_t half_N) {                                                    \
    static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");        \
    const DFromV<decltype(dest)> d;                                            \
    /* Both paths operate on `v` after it has gone through Ext. */             \
    const auto extended = Ext(d, v);                                           \
    HWY_IF_CONSTEXPR(kIndex != 0) {                                            \
      /* Upper half. */                                                        \
      return SlideUp(dest, extended, half_N);                                  \
    }                                                                          \
    else {                                                                     \
      /* Lower half. */                                                        \
      return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, extended,        \
                                                        half_N);               \
    }                                                                          \
  }
#define HWY_RVV_SET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
SHIFT, MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_API HWY_RVV_V(BASE, SEW, LMUL) \
NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMULH) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \
auto df2 = \
HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT - 1){}; \
HWY_IF_CONSTEXPR(kIndex == 0) { \
return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, Ext(d, v), \
Lanes(df2)); \
} \
else { \
return SlideUp(dest, Ext(d, v), Lanes(df2)); \
} \
// Generates NAME<kIndex>(dest, v, half_N) for the case where `v` has the same
// vector type as `dest`, so no Ext call is made. For kIndex == 0 the `_tu`
// form of the `__riscv_v<OP>_v_v_*` intrinsic is called with `half_N` as its
// last argument; for kIndex == 1, SlideUp is called with `half_N` as the
// amount.
#define HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST(                                 \
    BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, MLEN, NAME, OP)    \
  template <size_t kIndex>                                                     \
  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                           \
      NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v,      \
           size_t half_N) {                                                    \
    static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");        \
    HWY_IF_CONSTEXPR(kIndex != 0) {                                            \
      /* Upper half. */                                                        \
      return SlideUp(dest, v, half_N);                                         \
    }                                                                          \
    else {                                                                     \
      /* Lower half. */                                                        \
      return __riscv_v##OP##_v_v_##CHAR##SEW##LMUL##_tu(dest, v, half_N);      \
    }                                                                          \
  }
#define HWY_RVV_SET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
SHIFT, MLEN, NAME, OP) \
template <size_t kIndex> \
HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv, _GET_SET)
// PartialVecSetHalf is instantiated for three type lists. OP = mv is pasted
// into the intrinsic name, yielding `__riscv_vmv_v_v_*_tu`. The _GET_SET and
// _GET_SET_VIRT lists use the variant whose `v` has the half-LMUL type; the
// _GET_SET_SMALLEST list uses the variant whose `v` has the same type as
// `dest`.
HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF, PartialVecSetHalf, mv,
_GET_SET_VIRT)
HWY_RVV_FOREACH(HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST, PartialVecSetHalf, mv,
_GET_SET_SMALLEST)
// The generator macros are not needed after instantiation.
#undef HWY_RVV_PARTIAL_VEC_SET_HALF
#undef HWY_RVV_PARTIAL_VEC_SET_HALF_SMALLEST

#define HWY_RVV_SET(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
MLEN, NAME, OP) \
template <size_t kIndex, size_t N> \
HWY_API HWY_RVV_V(BASE, SEW, LMUL) \
NAME(HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) { \
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1"); \
auto d = HWY_RVV_D(BASE, SEW, HWY_LANES(HWY_RVV_T(BASE, SEW)), SHIFT){}; \
HWY_IF_CONSTEXPR(kIndex == 0) { \
return __riscv_vmv_v_v_##CHAR##SEW##LMUL##_tu(dest, v, Lanes(d) / 2); \
NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \
HWY_RVV_V(BASE, SEW, LMULH) v) { \
HWY_IF_CONSTEXPR(detail::IsFull(d)) { \
return __riscv_v##OP##_v_##CHAR##SEW##LMULH##_##CHAR##SEW##LMUL( \
dest, kIndex, v); /* no AVL */ \
} \
else { \
return SlideUp(dest, v, Lanes(d) / 2); \
const Half<decltype(d)> dh; \
return PartialVecSetHalf<kIndex>(dest, v, Lanes(dh)); \
} \
}
// Generates NAME<kIndex, N>(d, dest, v) for the _VIRT type list: always
// forwards to PartialVecSetHalf<kIndex>, passing the lane count of the
// Half<> of `d`'s tag type as the size of one half.
#define HWY_RVV_SET_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH,     \
                         SHIFT, MLEN, NAME, OP)                               \
  template <size_t kIndex, size_t N>                                          \
  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                          \
      NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \
           HWY_RVV_V(BASE, SEW, LMULH) v) {                                   \
    /* Only the type of `d` is used; the half tag is built from it. */        \
    return PartialVecSetHalf<kIndex>(dest, v, Lanes(Half<decltype(d)>()));    \
  }
// Generates NAME<kIndex, N>(d, dest, v) for the case where `v` has the same
// vector type as `dest`: forwards to PartialVecSetHalf<kIndex> with
// Lanes(d) / 2 as the size of one half.
#define HWY_RVV_SET_SMALLEST(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, \
                             SHIFT, MLEN, NAME, OP)                           \
  template <size_t kIndex, size_t N>                                          \
  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                          \
      NAME(HWY_RVV_D(BASE, SEW, N, SHIFT) d, HWY_RVV_V(BASE, SEW, LMUL) dest, \
           HWY_RVV_V(BASE, SEW, LMUL) v) {                                    \
    const size_t lanes_per_half = Lanes(d) / 2;                               \
    return PartialVecSetHalf<kIndex>(dest, v, lanes_per_half);                \
  }
// Same as the SMALLEST variant (`v` and `dest` share a vector type, half size
// is Lanes(d) / 2), except that the tag parameter `d` is declared with
// SHIFT - 1 rather than SHIFT.
#define HWY_RVV_SET_SMALLEST_VIRT(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, \
                                  LMULH, SHIFT, MLEN, NAME, OP)             \
  template <size_t kIndex, size_t N>                                        \
  HWY_API HWY_RVV_V(BASE, SEW, LMUL)                                        \
      NAME(HWY_RVV_D(BASE, SEW, N, SHIFT - 1) d,                            \
           HWY_RVV_V(BASE, SEW, LMUL) dest, HWY_RVV_V(BASE, SEW, LMUL) v) { \
    const size_t lanes_per_half = Lanes(d) / 2;                             \
    return PartialVecSetHalf<kIndex>(dest, v, lanes_per_half);              \
  }
// Instantiate the Set<kIndex>(d, dest, v) overloads. OP = set is only pasted
// into an intrinsic name by HWY_RVV_SET; the other three generators forward
// to PartialVecSetHalf and ignore OP.
HWY_RVV_FOREACH(HWY_RVV_SET, Set, set, _GET_SET)
HWY_RVV_FOREACH(HWY_RVV_SET_VIRT, Set, set, _GET_SET_VIRT)
HWY_RVV_FOREACH(HWY_RVV_SET_SMALLEST, Set, set, _GET_SET_SMALLEST)
// NOTE(review): the SMALLEST_VIRT overloads are generated only through the
// _UI163264 and _F iterators, presumably to exclude 8-bit lanes. Confirm
// against the HWY_RVV_FOREACH_* definitions.
HWY_RVV_FOREACH_UI163264(HWY_RVV_SET_SMALLEST_VIRT, Set, set, _GET_SET_SMALLEST)
HWY_RVV_FOREACH_F(HWY_RVV_SET_SMALLEST_VIRT, Set, set, _GET_SET_SMALLEST)
// The generator macros are not needed after instantiation.
#undef HWY_RVV_SET
#undef HWY_RVV_SET_VIRT
#undef HWY_RVV_SET_SMALLEST
#undef HWY_RVV_SET_SMALLEST_VIRT

template <size_t kIndex, class D>
template <size_t kIndex, class D, HWY_RVV_IF_EMULATED_D(D)>
static HWY_INLINE HWY_MAYBE_UNUSED VFromD<D> Set(
D d, VFromD<D> dest, VFromD<AdjustSimdTagToMinVecPow2<Half<D>>> v) {
static_assert(kIndex == 0 || kIndex == 1, "kIndex must be 0 or 1");

const AdjustSimdTagToMinVecPow2<Half<decltype(d)>> dh;
HWY_IF_CONSTEXPR(kIndex == 0 || detail::IsFull(d)) {
(void)dh;
return Set<kIndex>(dest, v);
}
else {
const size_t slide_up_amt =
(dh.Pow2() < DFromV<decltype(v)>().Pow2()) ? Lanes(dh) : (Lanes(d) / 2);
return SlideUp(dest, ResizeBitCast(d, v), slide_up_amt);
}
const RebindToUnsigned<decltype(d)> du;
return BitCast(
d, Set<kIndex>(du, BitCast(du, dest),
BitCast(RebindToUnsigned<DFromV<decltype(v)>>(), v)));
}

} // namespace detail
Expand Down
Loading