Accept left offsets in the masked softmax operator #1370

Draft · wants to merge 2 commits into master
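This PR lets the masked SoftMax operator accept an optional per-row offsets input in addition to the existing per-row lengths, so a row can mask a left prefix (e.g. left padding) as well as a right suffix. For row i, only the window [offsets[i], offsets[i] + lengths[i]) participates in the softmax; every position outside the window is written as 0. A minimal usage sketch, with values borrowed from the new MaskedSoftMaxLeftPadding test at the end of this diff (`device` as in the tests):

// Row 0 runs softmax over positions [1, 4); row 1 over positions [0, 4).
StorageView x({2, 5}, std::vector<float>{
    0.0, -0.2, 3.0, 1.2, -1.1,
    4.6, 3.3, 0.2, -1.6, 1.0}, device);
StorageView lengths({2}, std::vector<int32_t>{3, 4}, device);
StorageView offsets({2}, std::vector<int32_t>{1, 0}, device);
StorageView y(x.dtype(), device);
ops::SoftMax()(x, lengths, offsets, y);  // masked positions are set to 0 in y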
include/ctranslate2/ops/softmax.h (+12, -2)

@@ -13,11 +13,21 @@ namespace ctranslate2 {
       void operator()(StorageView& x) const;
       void operator()(const StorageView& x, StorageView& y) const override;
       void operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const;
-      void operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const;
+      void operator()(const StorageView& x,
+                      const StorageView& lengths,
+                      const StorageView& offsets,
+                      StorageView& y) const;
+      void operator()(const StorageView& x,
+                      const StorageView* lengths,
+                      const StorageView* offsets,
+                      StorageView& y) const;

     private:
       template <Device D, typename T>
-      void compute(const StorageView& input, const StorageView* lengths, StorageView& output) const;
+      void compute(const StorageView& input,
+                   const StorageView* lengths,
+                   const StorageView* offsets,
+                   StorageView& output) const;

       bool _log;
     };
src/cpu/kernels.cc (+14, -12)

@@ -367,6 +367,7 @@ namespace ctranslate2 {
     template<>
     void softmax<TARGET_ISA>(const float* input,
                              const int32_t* lengths,
+                             const int32_t* offsets,
                              float* output,
                              dim_t batch_size,
                              dim_t depth,
@@ -376,23 +377,24 @@ namespace ctranslate2 {
       parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
         for (dim_t i = begin; i < end; ++i) {
+          const dim_t start = offsets ? offsets[i] : 0;
+          const dim_t size = lengths ? lengths[i] : depth - start;
+
+          if (size == 0)
+            continue;
+
           const dim_t offset = i * depth;
           const float* x = input + offset;
           float* y = output + offset;

-          dim_t size = depth;
-          if (lengths) {
-            size = lengths[i];
-
-            // Directly set 0 in output for out of range positions.
-            for (dim_t j = size; j < depth; ++j) {
-              y[j] = 0;
-            }
+          // Directly set 0 in output for out of range positions.
+          for (dim_t j = 0; j < start; ++j)
+            y[j] = 0;
+          for (dim_t j = start + size; j < depth; ++j)
+            y[j] = 0;

-            if (size == 0) {
-              continue;
-            }
-          }
+          x += start;
+          y += start;

           const auto x_max = reduce_max<TARGET_ISA>(x, size);
           const auto vec_x_max = VecType::load(x_max);
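For reference, the vectorized kernel above implements the following scalar semantics. This is an illustrative sketch, not code from the PR: `reduce_max` and the SIMD types are replaced by naive loops, and only the plain (non-log) softmax path is shown.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar reference for the masked softmax: in each row, only the window
// [start, start + size) participates; every other output is zeroed.
void masked_softmax_reference(const float* input,
                              const int32_t* lengths,
                              const int32_t* offsets,
                              float* output,
                              int64_t batch_size,
                              int64_t depth) {
  for (int64_t i = 0; i < batch_size; ++i) {
    const float* x = input + i * depth;
    float* y = output + i * depth;
    const int64_t start = offsets ? offsets[i] : 0;
    const int64_t size = lengths ? lengths[i] : depth - start;
    std::fill(y, y + depth, 0.f);  // zero everything, then fill the window
    if (size == 0)
      continue;
    const float x_max = *std::max_element(x + start, x + start + size);
    float sum = 0.f;
    for (int64_t j = start; j < start + size; ++j) {
      y[j] = std::exp(x[j] - x_max);  // shift by the max for stability
      sum += y[j];
    }
    for (int64_t j = start; j < start + size; ++j)
      y[j] /= sum;
  }
}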
src/cpu/kernels.h (+1, -0)

@@ -64,6 +64,7 @@ namespace ctranslate2 {
     template <CpuIsa ISA>
     void softmax(const float* input,
                  const int32_t* lengths,
+                 const int32_t* offsets,
                  float* output,
                  dim_t batch_size,
                  dim_t depth,
src/layers/attention.cc (+1, -1)

@@ -262,7 +262,7 @@ namespace ctranslate2 {
         alibi->apply(output);

       StorageView attn(values.dtype(), values.device());
-      ops::SoftMax()(output, values_lengths, attn);
+      ops::SoftMax()(output, values_lengths, nullptr, attn);

       if (attention && !return_normalized_attention)
         save_attention(*attention, std::move(output), beam_size);
src/ops/softmax.cc (+28, -13)

@@ -14,36 +14,51 @@ namespace ctranslate2 {
     }

     void SoftMax::operator()(StorageView& x) const {
-      operator()(x, nullptr, x);
+      operator()(x, nullptr, nullptr, x);
     }

     void SoftMax::operator()(const StorageView& x, StorageView& y) const {
-      operator()(x, nullptr, y);
+      operator()(x, nullptr, nullptr, y);
     }

     void SoftMax::operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const {
-      operator()(x, &lengths, y);
+      operator()(x, &lengths, nullptr, y);
     }

-    void SoftMax::operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const {
+    void SoftMax::operator()(const StorageView& x,
+                             const StorageView& lengths,
+                             const StorageView& offsets,
+                             StorageView& y) const {
+      operator()(x, &lengths, &offsets, y);
+    }
+
+    void SoftMax::operator()(const StorageView& x,
+                             const StorageView* lengths,
+                             const StorageView* offsets,
+                             StorageView& y) const {
       PROFILE(_log ? "LogSoftMax" : "SoftMax");
       y.resize_as(x);

       const dim_t depth = x.dim(-1);
+      const dim_t batch_size = x.size() / depth;

       if (depth == 0)
         return;

-      if (lengths) {
-        const dim_t batch_size = x.size() / depth;
-        if (lengths->size() != batch_size)
-          throw std::invalid_argument("Length mask has size "
-                                      + std::to_string(lengths->size())
-                                      + " which is different than the current batch size "
-                                      + std::to_string(batch_size));
-      }
+      if (lengths && lengths->size() != batch_size)
+        throw std::invalid_argument("Length mask has size "
+                                    + std::to_string(lengths->size())
+                                    + " which is different than the current batch size "
+                                    + std::to_string(batch_size));
+
+      if (offsets && offsets->size() != batch_size)
+        throw std::invalid_argument("Offsets input has size "
+                                    + std::to_string(offsets->size())
+                                    + " which is different than the current batch size "
+                                    + std::to_string(batch_size));

-      DEVICE_AND_FLOAT_DISPATCH("SoftMax", x.device(), x.dtype(), (compute<D, T>(x, lengths, y)));
+      DEVICE_AND_FLOAT_DISPATCH("SoftMax", x.device(), x.dtype(),
+                                (compute<D, T>(x, lengths, offsets, y)));
     }

   }
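The two checks above make the batch-size validation symmetric between the lengths and offsets masks. A hypothetical call that would trigger the new offsets check (shapes chosen for illustration):

// x has shape {2, 5}, so the flattened batch size is 2,
// but offsets carries 3 entries:
StorageView x({2, 5}, std::vector<float>(10, 1.f), device);
StorageView lengths({2}, std::vector<int32_t>{3, 4}, device);
StorageView offsets({3}, std::vector<int32_t>{1, 0, 2}, device);
StorageView y(x.dtype(), device);
// Throws std::invalid_argument: "Offsets input has size 3 which is
// different than the current batch size 2".
ops::SoftMax()(x, lengths, offsets, y);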
src/ops/softmax_cpu.cc (+3, -0)

@@ -8,13 +8,15 @@ namespace ctranslate2 {
     template <Device D, typename T>
     void SoftMax::compute(const StorageView& input,
                           const StorageView* lengths,
+                          const StorageView* offsets,
                           StorageView& output) const {
       constexpr float epsilon = 0.000001f;
       const dim_t depth = input.dim(-1);
       const dim_t batch_size = input.size() / depth;

       CPU_ISA_DISPATCH((cpu::softmax<ISA>(input.data<T>(),
                                           lengths ? lengths->data<int32_t>() : nullptr,
+                                          offsets ? offsets->data<int32_t>() : nullptr,
                                           output.data<T>(),
                                           batch_size,
                                           depth,
@@ -26,6 +28,7 @@ namespace ctranslate2 {
 #define DECLARE_IMPL(T)                                                  \
     template void                                                        \
     SoftMax::compute<Device::CPU, T>(const StorageView& input,           \
                                      const StorageView* lengths,         \
+                                     const StorageView* offsets,         \
                                      StorageView& output) const;

   DECLARE_IMPL(float)
src/ops/softmax_gpu.cu (+24, -10)

@@ -13,11 +13,13 @@ namespace ctranslate2 {
                              const dim_t rows,
                              const dim_t cols,
                              const int32_t* lengths,
+                             const int32_t* offsets,
                              T* y);

     template <Device D, typename T>
     void SoftMax::compute(const StorageView& input,
                           const StorageView* lengths,
+                          const StorageView* offsets,
                           StorageView& output) const {
       const dim_t depth = input.dim(-1);
       const dim_t batch_size = input.size() / depth;
@@ -27,13 +29,15 @@ namespace ctranslate2 {
                            batch_size,
                            depth,
                            lengths ? lengths->data<int32_t>() : nullptr,
+                           offsets ? offsets->data<int32_t>() : nullptr,
                            output.data<T>());
     }

 #define DECLARE_IMPL(T)                                                  \
     template void                                                        \
     SoftMax::compute<Device::CUDA, T>(const StorageView& input,          \
                                       const StorageView* lengths,        \
+                                      const StorageView* offsets,        \
                                       StorageView& output) const;

     DECLARE_IMPL(float)
@@ -197,7 +201,8 @@ namespace at {
   cunn_SoftMaxForward(outscalar_t *output,
                       const scalar_t *input,
                       const index_t classes,
-                      const length_t *lengths)
+                      const length_t *lengths,
+                      const length_t *offsets)
   {
     extern __shared__ unsigned char smem[];
     auto sdata = reinterpret_cast<accscalar_t*>(smem);
@@ -207,15 +212,21 @@ namespace at {
     input += row * classes;
     output += row * classes;

-    index_t size = classes;
-    if (lengths)
-    {
+    const index_t start = offsets ? offsets[row] : 0;
+    const index_t size = lengths ? lengths[row] : classes - start;
+    const index_t end = start + size;
+
+    if (start > 0 || end < classes) {
       // Directly set 0 in output for out of range positions.
-      size = lengths[row];
-      for (index_t i = size + threadIdx.x; i < classes; i += blockDim.x)
-        output[i] = 0.f;
+      for (index_t i = threadIdx.x; i < classes; i += blockDim.x) {
+        if (i < start || i >= end)
+          output[i] = 0.f;
+      }
     }

+    input += start;
+    output += start;
+
     // find the max
     accscalar_t threadMax = ctranslate2::cuda::ilp_reduce(
         input, size, MaxFloat<scalar_t, accscalar_t>(), -max_float);
@@ -245,14 +256,16 @@ namespace ctranslate2 {
                              const dim_t rows,
                              const dim_t cols,
                              const int32_t* lengths,
+                             const int32_t* offsets,
                              T* y) {
       const dim3 grid(rows);
       const dim3 block(cuda::get_block_size(cols));
       at::native::cunn_SoftMaxForward<T, float, T, cuda::index_t, int32_t, Epilogue>
         <<<grid, block, block.x * sizeof (float), stream>>>(y,
                                                             x,
                                                             cols,
-                                                            lengths);
+                                                            lengths,
+                                                            offsets);
     }

     template <typename T>
@@ -262,13 +275,14 @@ namespace ctranslate2 {
                          const dim_t rows,
                          const dim_t cols,
                          const int32_t* lengths,
+                         const int32_t* offsets,
                          T* y) {
       if (log_softmax)
         softmax_kernel_impl<cuda::device_type<T>, at::native::LogSoftMaxForwardEpilogue>(
-          stream, cuda::device_cast(x), rows, cols, lengths, cuda::device_cast(y));
+          stream, cuda::device_cast(x), rows, cols, lengths, offsets, cuda::device_cast(y));
       else
         softmax_kernel_impl<cuda::device_type<T>, at::native::SoftMaxForwardEpilogue>(
-          stream, cuda::device_cast(x), rows, cols, lengths, cuda::device_cast(y));
+          stream, cuda::device_cast(x), rows, cols, lengths, offsets, cuda::device_cast(y));
     }

   }
tests/ops_test.cc (+17, -0)

@@ -668,6 +668,23 @@ TEST_P(OpDeviceFPTest, MaskedSoftMax) {
   expect_storage_eq(y.to_float32(), expected, error);
 }

+TEST_P(OpDeviceFPTest, MaskedSoftMaxLeftPadding) {
+  const Device device = GetParam().device;
+  const DataType dtype = GetParam().dtype;
+  const float error = GetParam().error;
+  StorageView x({2, 5}, std::vector<float>{
+      0.0, -0.2, 3.0, 1.2, -1.1,
+      4.6, 3.3, 0.2, -1.6, 1.0}, device);
+  StorageView lengths({2}, std::vector<int32_t>{3, 4}, device);
+  StorageView offsets({2}, std::vector<int32_t>{1, 0}, device);
+  StorageView expected({2, 5}, std::vector<float>{
+      0, 0.033797, 0.829145, 0.137056, 0,
+      0.777098, 0.211783, 0.009540, 0.001577, 0}, device);
+  StorageView y(dtype, device);
+  ops::SoftMax()(x.to(dtype), lengths, offsets, y);
+  expect_storage_eq(y.to_float32(), expected, error);
+}
+
 TEST_P(OpDeviceFPTest, MaskedSoftMaxTriangular) {
   const Device device = GetParam().device;
   const DataType dtype = GetParam().dtype;
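The expected values in MaskedSoftMaxLeftPadding follow from applying a plain softmax inside each row's window: row 0 uses {-0.2, 3.0, 1.2} (offset 1, length 3) and row 1 uses {4.6, 3.3, 0.2, -1.6} (offset 0, length 4). A standalone sanity check for row 0:

#include <cmath>
#include <cstdio>

// Softmax over row 0's window; positions outside the window stay 0,
// so the full row becomes {0, 0.033797, 0.829145, 0.137056, 0}.
int main() {
  const double window[] = {-0.2, 3.0, 1.2};
  double sum = 0;
  for (double v : window)
    sum += std::exp(v);
  for (double v : window)
    std::printf("%.6f\n", std::exp(v) / sum);  // 0.033797 0.829145 0.137056
}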