diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
index 098bd487f..79b109387 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp
@@ -29,7 +29,7 @@ BaseModel::BaseModel(const std::string &modelSource,
 }
 
 std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
-                                              int32_t index) {
+                                              int32_t index) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get input shape");
   }
@@ -55,7 +55,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
 }
 
 std::vector<std::vector<int32_t>>
-BaseModel::getAllInputShapes(std::string methodName) {
+BaseModel::getAllInputShapes(std::string methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get all input shapes");
   }
@@ -87,7 +87,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
 /// to JS. It is not meant to be used within C++. If you want to call forward
 /// from C++ on a BaseModel, please use BaseModel::forward.
 std::vector
-BaseModel::forwardJS(std::vector tensorViewVec) {
+BaseModel::forwardJS(std::vector tensorViewVec) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
@@ -135,7 +135,7 @@ BaseModel::forwardJS(std::vector tensorViewVec) {
 }
 
 Result
-BaseModel::getMethodMeta(const std::string &methodName) {
+BaseModel::getMethodMeta(const std::string &methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get method meta!");
   }
@@ -160,7 +160,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {
 
 Result<std::vector<EValue>>
 BaseModel::execute(const std::string &methodName,
-                   const std::vector<EValue> &input_value) {
+                   const std::vector<EValue> &input_value) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded, cannot run execute.");
   }
@@ -174,7 +174,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
 void BaseModel::unload() noexcept { module_.reset(nullptr); }
 
 std::vector<int32_t>
-BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
+BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
   auto sizes = tensor.sizes();
   return std::vector<int32_t>(sizes.begin(), sizes.end());
 }
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
index 983dc9b74..b944c590a 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h
@@ -21,18 +21,20 @@ class BaseModel {
                      std::shared_ptr callInvoker);
   std::size_t getMemoryLowerBound() const noexcept;
   void unload() noexcept;
-  std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
+  std::vector<int32_t> getInputShape(std::string method_name,
+                                     int32_t index) const;
   std::vector<std::vector<int32_t>>
-  getAllInputShapes(std::string methodName = "forward");
+  getAllInputShapes(std::string methodName = "forward") const;
   std::vector
-  forwardJS(std::vector tensorViewVec);
+  forwardJS(std::vector tensorViewVec) const;
   Result<std::vector<EValue>> forward(const EValue &input_value) const;
   Result<std::vector<EValue>>
   forward(const std::vector<EValue> &input_value) const;
-  Result<std::vector<EValue>> execute(const std::string &methodName,
-                                      const std::vector<EValue> &input_value);
+  Result<std::vector<EValue>>
+  execute(const std::string &methodName,
+          const std::vector<EValue> &input_value) const;
   Result
-  getMethodMeta(const std::string &methodName);
+  getMethodMeta(const std::string &methodName) const;
 
 protected:
   // If possible, models should not use the JS runtime to keep JSI internals
@@ -44,7 +46,8 @@ class BaseModel {
 
 private:
   std::size_t memorySizeLowerBound{0};
-  std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
+  std::vector<int32_t>
+  getTensorShape(const executorch::aten::Tensor &tensor) const;
 };
 
 } // namespace models
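A note on why the const-qualification above matters downstream: ASR (below) holds its encoder and decoder as `const models::BaseModel *`, so every member function it calls through those pointers must be const. A minimal sketch of the usage this enables; the helper name is illustrative and not part of this PR:

```cpp
#include <cstddef>

#include "rnexecutorch/models/BaseModel.h"

// Illustrative helper, not from the PR: with the const qualifiers added
// above, read-only inspection can happen through a const pointer, which
// is exactly how ASR stores `const models::BaseModel *encoder`/`*decoder`.
std::size_t countForwardInputs(const rnexecutorch::models::BaseModel *model) {
  // getAllInputShapes() is now callable on a const object; before this
  // change the call below would not have compiled.
  return model->getAllInputShapes("forward").size();
}
```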
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
index d0f965cb3..bf8f9fb86 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.cpp
@@ -4,7 +4,6 @@
 #include "ASR.h"
 #include "executorch/extension/tensor/tensor_ptr.h"
 #include "rnexecutorch/data_processing/Numerical.h"
-#include "rnexecutorch/data_processing/dsp.h"
 #include "rnexecutorch/data_processing/gzip.h"
 
 namespace rnexecutorch::models::speech_to_text::asr {
@@ -37,8 +36,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
   return seq;
 }
 
-GenerationResult ASR::generate(std::span<const float> waveform,
-                               float temperature,
+GenerationResult ASR::generate(std::span<float> waveform, float temperature,
                                const DecodingOptions &options) const {
   std::vector<float> encoderOutput = this->encode(waveform);
 
@@ -94,7 +92,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
 }
 
 std::vector<int32_t>
-ASR::generateWithFallback(std::span<const float> waveform,
+ASR::generateWithFallback(std::span<float> waveform,
                           const DecodingOptions &options) const {
   std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
   std::vector<int32_t> bestTokens;
@@ -209,7 +207,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<int32_t> tokens,
   return wordObjs;
 }
 
-std::vector ASR::transcribe(std::span<const float> waveform,
+std::vector ASR::transcribe(std::span<float> waveform,
                             const DecodingOptions &options) const {
   int32_t seek = 0;
   std::vector results;
@@ -218,7 +216,7 @@ std::vector ASR::transcribe(std::span<const float> waveform,
     int32_t start = seek * ASR::kSamplingRate;
     const auto end = std::min<std::size_t>(
        (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
-    std::span<const float> chunk = waveform.subspan(start, end - start);
+    auto chunk = waveform.subspan(start, end - start);
 
     if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
       break;
@@ -246,19 +244,12 @@ std::vector ASR::transcribe(std::span<const float> waveform,
   return results;
 }
 
-std::vector<float> ASR::encode(std::span<const float> waveform) const {
-  constexpr int32_t fftWindowSize = 512;
-  constexpr int32_t stftHopLength = 160;
-  constexpr int32_t innerDim = 256;
-
-  std::vector<float> preprocessedData =
-      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
-  const auto numFrames =
-      static_cast<int32_t>(preprocessedData.size()) / innerDim;
-  std::vector<int32_t> inputShape = {numFrames, innerDim};
+std::vector<float> ASR::encode(std::span<float> waveform) const {
+  auto inputShape = {static_cast<int32_t>(waveform.size())};
 
   const auto modelInputTensor = executorch::extension::make_tensor_ptr(
-      std::move(inputShape), std::move(preprocessedData));
+      std::move(inputShape), waveform.data(),
+      executorch::runtime::etensor::ScalarType::Float);
 
   const auto encoderResult = this->encoder->forward(modelInputTensor);
   if (!encoderResult.ok()) {
@@ -268,7 +259,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
   }
 
   const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
-  const int32_t outputNumel = decoderOutputTensor.numel();
+  const auto outputNumel = decoderOutputTensor.numel();
   const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
   return {dataPtr, dataPtr + outputNumel};
 }
@@ -277,8 +268,10 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
 std::vector<float> ASR::decode(std::span<int32_t> tokens,
                                std::span<float> encoderOutput) const {
   std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
+  auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
+
   auto tokenTensor = executorch::extension::make_tensor_ptr(
-      std::move(tokenShape), tokens.data(), ScalarType::Int);
+      tokenShape, tokensLong.data(), ScalarType::Long);
 
   const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
   std::vector<int32_t> encShape = {1, ASR::kNumFrames,
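Two details in the `ASR::decode` hunk above are worth calling out: token ids are widened from `int32_t` into a separate `int64_t` buffer because the decoder now expects a Long tensor, and that buffer is kept in a named local (`tokensLong`) because the raw-pointer overload of `make_tensor_ptr` wraps caller-owned memory rather than copying it. A standalone sketch of the same pattern; the function name is illustrative, only the `make_tensor_ptr` call mirrors the diff:

```cpp
#include <cstdint>
#include <span>
#include <vector>

#include "executorch/extension/tensor/tensor_ptr.h"

// Illustrative, not from the PR: wrap int32 token ids in an int64 tensor.
// The widened storage must outlive every use of the tensor, because the
// raw-pointer make_tensor_ptr overload does not take ownership of `data`.
void decodeOnce(std::span<const int32_t> tokens) {
  std::vector<int64_t> tokensLong(tokens.begin(), tokens.end());
  std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
  auto tokenTensor = executorch::extension::make_tensor_ptr(
      tokenShape, tokensLong.data(), executorch::aten::ScalarType::Long);
  // ... run the decoder with tokenTensor before tokensLong goes out of scope.
}
```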
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
index 20180ebe4..a0ea7e181 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/asr/ASR.h
@@ -14,9 +14,9 @@ class ASR {
       const models::BaseModel *decoder, const TokenizerModule *tokenizer);
   std::vector
-  transcribe(std::span<const float> waveform,
+  transcribe(std::span<float> waveform,
              const types::DecodingOptions &options) const;
-  std::vector<float> encode(std::span<const float> waveform) const;
+  std::vector<float> encode(std::span<float> waveform) const;
   std::vector<float> decode(std::span<int32_t> tokens,
                             std::span<float> encoderOutput) const;
@@ -44,11 +44,10 @@ class ASR {
   std::vector<int32_t>
   getInitialSequence(const types::DecodingOptions &options) const;
-  types::GenerationResult generate(std::span<const float> waveform,
-                                   float temperature,
+  types::GenerationResult generate(std::span<float> waveform, float temperature,
                                    const types::DecodingOptions &options) const;
   std::vector<int32_t>
-  generateWithFallback(std::span<const float> waveform,
+  generateWithFallback(std::span<float> waveform,
                        const types::DecodingOptions &options) const;
   std::vector
   calculateWordLevelTimestamps(std::span<int32_t> tokens,
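Finally, the chunking arithmetic behind `transcribe` (declared above): `seek` counts seconds, so each iteration examines samples `[seek * kSamplingRate, (seek + kChunkSize) * kSamplingRate)`, clamped to the waveform length. A worked sketch, assuming Whisper-style constants `kSamplingRate = 16000` and `kChunkSize = 30`; the real values are defined in ASR.h and may differ:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed values; the actual constants live in ASR.h.
  constexpr int32_t kSamplingRate = 16000; // samples per second
  constexpr int32_t kChunkSize = 30;       // seconds per chunk
  const std::size_t waveformSize = 50ull * kSamplingRate; // a 50 s clip

  int32_t seek = 0;
  while (static_cast<std::size_t>(seek) * kSamplingRate < waveformSize) {
    const auto start = static_cast<std::size_t>(seek) * kSamplingRate;
    const auto end = std::min<std::size_t>(
        static_cast<std::size_t>(seek + kChunkSize) * kSamplingRate,
        waveformSize);
    std::printf("chunk [%zu, %zu): %zu samples\n", start, end, end - start);
    // The real loop advances seek by the audio actually consumed; this
    // sketch simply steps one whole chunk at a time.
    seek += kChunkSize;
  }
  return 0;
}
// Output:
//   chunk [0, 480000): 480000 samples
//   chunk [480000, 800000): 320000 samples
```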