Add support for Useful Sensors Moonshine model. #1808

Open · wants to merge 11 commits into base: master
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -22,6 +22,11 @@ option(BUILD_TESTS "Compile the tests" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
option(MOONSHINE "Compile with moonshine specializations" OFF)

if (MOONSHINE)
add_definitions(-DMOONSHINE)
endif()

if(ENABLE_PROFILING)
message(STATUS "Enable profiling support")
@@ -129,6 +134,7 @@ set(SOURCES
src/layers/wav2vec2.cc
src/layers/wav2vec2bert.cc
src/layers/whisper.cc
src/layers/moonshine.cc
src/logging.cc
src/models/language_model.cc
src/models/model.cc
@@ -139,6 +145,7 @@
src/models/wav2vec2.cc
src/models/wav2vec2bert.cc
src/models/whisper.cc
src/models/moonshine.cc
src/ops/activation.cc
src/ops/add.cc
src/ops/alibi_add.cc
1 change: 1 addition & 0 deletions docs/conversion.md
@@ -8,6 +8,7 @@ The Python module includes a [conversion API](python/ctranslate2.converters.rst)

* [Fairseq](guides/fairseq.md)
* [Marian](guides/marian.md)
* [Moonshine](guides/moonshine.md)
* [OpenNMT-py](guides/opennmt_py.md)
* [OpenNMT-tf](guides/opennmt_tf.md)
* [OPUS-MT](guides/opus_mt.md)
10 changes: 10 additions & 0 deletions docs/guides/moonshine.md
@@ -0,0 +1,10 @@
# Moonshine

CTranslate2 supports [Moonshine](https://github.com/usefulsensors/moonshine) transcription models. The conversion requires the paths to the model.safetensors and tokenizer.json files.

Download the model.safetensors and tokenizer.json files from [Moonshine Tiny](https://huggingface.co/UsefulSensors/moonshine-tiny/tree/main) or [Moonshine Base](https://huggingface.co/UsefulSensors/moonshine-base/tree/main), and pass the matching variant name to `--moonshine_variant`.

```bash
ct2-moonshine-converter --model_path model.safetensors --vocab_path tokenizer.json --moonshine_variant tiny \
--output_dir ct2_model
```
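Once converted, the model can be loaded through the C++ `Moonshine` class added by this PR (see `include/ctranslate2/models/moonshine.h`). The sketch below is illustrative only and not part of the PR: the `models::ModelLoader` entry point, the `<s>` start-of-text prompt, and the `[batch_size, 1, num_samples]` layout of 16 kHz float samples are assumptions rather than details documented in this guide.

```cpp
#include <iostream>
#include <vector>

#include "ctranslate2/models/moonshine.h"

int main() {
  // Assumed loader API: point it at the directory produced by the converter above.
  ctranslate2::models::ModelLoader loader("ct2_model");
  ctranslate2::models::Moonshine moonshine(loader);

  // One second of silence as placeholder audio (16 kHz mono float samples).
  std::vector<float> samples(16000, 0.f);
  ctranslate2::StorageView audio({1, 1, 16000}, samples);

  ctranslate2::models::MoonshineOptions options;
  options.beam_size = 1;  // greedy decoding

  // Prompt with the assumed start token; the vocabulary itself is stored with the converted model.
  const std::vector<std::vector<std::string>> prompts = {{"<s>"}};
  auto futures = moonshine.generate(audio, prompts, options);
  const auto result = futures[0].get();

  // Print the raw generated tokens (detokenization is left out of this sketch).
  for (const auto& token : result.sequences[0])
    std::cout << token;
  std::cout << std::endl;
  return 0;
}
```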
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -1,3 +1,4 @@
myst-parser==0.17.*
sphinx-rtd-theme==1.0.*
sphinx==4.5.*
safetensors[torch]
75 changes: 75 additions & 0 deletions include/ctranslate2/layers/moonshine.h
@@ -0,0 +1,75 @@
#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
namespace layers {

class MoonshinePreprocessor : public Layer {
public:
MoonshinePreprocessor(const models::Model& model, const std::string& scope);

void operator()(const StorageView& features, StorageView& output);

DataType output_type() const override {
return _conv3.output_type();
}

dim_t output_size() const override {
return _conv3.output_size();
}

dim_t input_size() const {
return _conv1.input_size();
}
private:
const Conv1D _conv1;
const ops::Tanh _tanh;
const LayerNorm _norm;
const Conv1D _conv2;
const ops::GELU _gelu1;
const Conv1D _conv3;
const ops::GELU _gelu2;
const ops::Transpose _transpose;
};


class MoonshineEncoder : public Layer {
public:
MoonshineEncoder(const models::Model& model, const std::string& scope);

void operator()(const StorageView& features, StorageView& output);

DataType output_type() const override {
return _output_norm.output_type();
}

dim_t output_size() const override {
return _output_norm.output_size();
}

bool is_encoded(const StorageView& features) const {
// Input features shape: [batch_size, input_size, input_time]
// Encoder output shape: [batch_size, downsampled_time, output_size]
//
// input_time is variable, so we instead check that dimension 1 no longer equals
// the input channel size (which is 1 for raw audio).

return (features.rank() == 3
&& features.dim(2) == output_size()
&& features.dim(1) != 1);
}

private:
const dim_t _num_heads;
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
const LayerNorm _output_norm;
};

class MoonshineDecoder : public TransformerDecoder {
public:
using TransformerDecoder::TransformerDecoder;

bool return_normalized_attention() const override {
return false;
}
};
}
}
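For orientation, here is a minimal sketch of how the two audio layers declared above could be chained. The scope names `"preprocessor"` and `"encoder"`, and the assumption that raw audio is passed straight to the preprocessor, are illustrative only and not prescribed by this header.

```cpp
#include "ctranslate2/layers/moonshine.h"

namespace example {

// Run raw audio through the convolutional preprocessor, then the Transformer encoder.
ctranslate2::StorageView encode_audio(const ctranslate2::models::Model& model,
                                      const ctranslate2::StorageView& audio) {
  // Hypothetical scope names; the real ones are defined by the converted model specification.
  ctranslate2::layers::MoonshinePreprocessor preprocessor(model, "preprocessor");
  ctranslate2::layers::MoonshineEncoder encoder(model, "encoder");

  // Pass through if the input was already encoded (see is_encoded above).
  if (encoder.is_encoded(audio))
    return audio;

  // Raw samples -> convolutional downsampling -> Transformer encoder stack.
  ctranslate2::StorageView features(preprocessor.output_type(), audio.device());
  preprocessor(audio, features);

  ctranslate2::StorageView encoded(encoder.output_type(), features.device());
  encoder(features, encoded);
  return encoded;
}

}
```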
134 changes: 134 additions & 0 deletions include/ctranslate2/models/moonshine.h
@@ -0,0 +1,134 @@
#pragma once

#include "ctranslate2/generation.h"
#include "ctranslate2/layers/moonshine.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
namespace models {

struct MoonshineOptions {
// Beam size to use for beam search (set 1 to run greedy search).
size_t beam_size = 5;

// Beam search patience factor, as described in https://arxiv.org/abs/2204.05424.
// The decoding will continue until beam_size*patience hypotheses are finished.
float patience = 1;

// Exponential penalty applied to the length during beam search.
float length_penalty = 1;

// Penalty applied to the score of previously generated tokens, as described in
// https://arxiv.org/abs/1909.05858 (set > 1 to penalize).
float repetition_penalty = 1;

// Prevent repetitions of ngrams with this size (set 0 to disable).
size_t no_repeat_ngram_size = 0;

// Maximum generation length.
size_t max_length = 448;

// Randomly sample from the top K candidates (set 0 to sample from the full distribution).
size_t sampling_topk = 1;

// High temperatures increase randomness.
float sampling_temperature = 1;

// Number of hypotheses to include in the result.
size_t num_hypotheses = 1;

// Include scores in the result.
bool return_scores = false;

// Suppress blank outputs at the beginning of the sampling.
bool suppress_blank = true;

// List of token IDs to suppress.
// -1 will suppress a default set of symbols as defined in the model config.json file.
std::vector<int> suppress_tokens = {-1};
};

struct MoonshineGenerationResult {
std::vector<std::vector<std::string>> sequences;
std::vector<std::vector<size_t>> sequences_ids;
std::vector<float> scores;

size_t num_sequences() const {
return sequences.size();
}

bool has_scores() const {
return !scores.empty();
}
};

class MoonshineModel : public Model {
public:
const Vocabulary& get_vocabulary() const;

size_t current_spec_revision() const override;
bool is_quantizable(const std::string& variable_name) const override;
bool is_linear_weight(const std::string& variable_name) const override;
std::unique_ptr<Model> clone() const override;

bool use_global_int16_scale() const override {
return false;
}

protected:
void initialize(ModelReader& model_reader) override;

private:
std::shared_ptr<const Vocabulary> _vocabulary;
};

class MoonshineReplica : public ModelReplica {
public:
static std::unique_ptr<MoonshineReplica> create_from_model(const Model& model);

MoonshineReplica(const std::shared_ptr<const MoonshineModel>& model);

StorageView encode(StorageView features, const bool to_cpu);

std::vector<MoonshineGenerationResult>
generate(StorageView features,
const std::vector<std::vector<std::string>>& prompts,
const MoonshineOptions& options);

std::vector<MoonshineGenerationResult>
generate(StorageView features,
const std::vector<std::vector<size_t>>& prompts,
const MoonshineOptions& options);

private:
const std::shared_ptr<const MoonshineModel> _model;
const std::unique_ptr<layers::MoonshinePreprocessor> _preprocessor;
const std::unique_ptr<layers::MoonshineEncoder> _encoder;
const std::unique_ptr<layers::MoonshineDecoder> _decoder;

size_t _sot_id;
size_t _eot_id;

StorageView maybe_encode(StorageView features);
};

class Moonshine : public ReplicaPool<MoonshineReplica> {
public:
using ReplicaPool::ReplicaPool;

std::future<StorageView> encode(const StorageView& features, const bool to_cpu);

std::vector<std::future<MoonshineGenerationResult>>
generate(const StorageView& features,
std::vector<std::vector<std::string>> prompts,
MoonshineOptions options = {});

std::vector<std::future<MoonshineGenerationResult>>
generate(const StorageView& features,
std::vector<std::vector<size_t>> prompts,
MoonshineOptions options = {});
};

}
}
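A short usage sketch for the replica pool declared above, focusing on option tuning and the two-step encode/generate flow. It assumes, as with the existing Whisper model, that the output of `encode()` can be passed back to `generate()`, and it uses a placeholder start-token ID; neither detail is stated in this header.

```cpp
#include "ctranslate2/models/moonshine.h"

ctranslate2::models::MoonshineGenerationResult
transcribe(ctranslate2::models::Moonshine& moonshine,
           const ctranslate2::StorageView& audio) {
  // Run the preprocessor and encoder once, keeping the result on the device.
  ctranslate2::StorageView encoded = moonshine.encode(audio, /*to_cpu=*/false).get();

  ctranslate2::models::MoonshineOptions options;
  options.beam_size = 5;             // explicit beam search (set 1 for greedy decoding)
  options.no_repeat_ngram_size = 3;  // block repeated trigrams
  options.return_scores = true;      // populate MoonshineGenerationResult::scores

  const size_t sot_id = 1;  // placeholder: the real <s> ID comes from the model vocabulary
  const std::vector<std::vector<size_t>> prompts = {{sot_id}};
  auto futures = moonshine.generate(encoded, prompts, options);
  return futures[0].get();
}
```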
4 changes: 3 additions & 1 deletion include/ctranslate2/ops/layer_norm.h
@@ -7,7 +7,7 @@ namespace ctranslate2 {

class LayerNorm : public TernaryOp {
public:
LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5);
LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5, const bool multi_axis=false);

using TernaryOp::operator();
void operator()(const StorageView& beta,
@@ -32,10 +32,12 @@ namespace ctranslate2 {
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
const bool multi_axis,
StorageView& output) const;

const dim_t _axis;
const float _epsilon;
const bool _multi_axis;
};

}
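For context, a minimal sketch constructing the op with and without the new flag. The exact semantics of `multi_axis` (letting the normalization statistics span more than the trailing dimension, as the Moonshine layers seem to require) is inferred from the name, not stated in this hunk.

```cpp
#include "ctranslate2/ops/layer_norm.h"

// Previous behavior: mean/variance computed over the last axis only.
const ctranslate2::ops::LayerNorm last_axis_norm(/*axis=*/-1, /*epsilon=*/1e-5f);

// New optional flag added by this PR.
const ctranslate2::ops::LayerNorm multi_axis_norm(/*axis=*/-1, /*epsilon=*/1e-5f, /*multi_axis=*/true);
```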
1 change: 1 addition & 0 deletions python/cpp/module.cc
@@ -89,4 +89,5 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_wav2vec2(m);
ctranslate2::python::register_wav2vec2bert(m);
ctranslate2::python::register_mpi(m);
ctranslate2::python::register_moonshine(m);
}
1 change: 1 addition & 0 deletions python/cpp/module.h
@@ -20,6 +20,7 @@ namespace ctranslate2 {
void register_wav2vec2(py::module& m);
void register_wav2vec2bert(py::module& m);
void register_mpi(py::module& m);
void register_moonshine(py::module& m);

}
}