BaseModel.cpp

@@ -29,7 +29,7 @@ BaseModel::BaseModel(const std::string &modelSource,
 }
 
 std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
-                                              int32_t index) {
+                                              int32_t index) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get input shape");
   }
@@ -55,7 +55,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
 }
 
 std::vector<std::vector<int32_t>>
-BaseModel::getAllInputShapes(std::string methodName) {
+BaseModel::getAllInputShapes(std::string methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get all input shapes");
   }
@@ -87,7 +87,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
 /// to JS. It is not meant to be used within C++. If you want to call forward
 /// from C++ on a BaseModel, please use BaseModel::forward.
 std::vector<JSTensorViewOut>
-BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
+BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot perform forward pass");
   }
@@ -135,7 +135,7 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
 }
 
 Result<executorch::runtime::MethodMeta>
-BaseModel::getMethodMeta(const std::string &methodName) {
+BaseModel::getMethodMeta(const std::string &methodName) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded: Cannot get method meta!");
   }
@@ -160,7 +160,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {
 
 Result<std::vector<EValue>>
 BaseModel::execute(const std::string &methodName,
-                   const std::vector<EValue> &input_value) {
+                   const std::vector<EValue> &input_value) const {
   if (!module_) {
     throw std::runtime_error("Model not loaded, cannot run execute.");
   }
@@ -174,7 +174,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
 void BaseModel::unload() noexcept { module_.reset(nullptr); }
 
 std::vector<int32_t>
-BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
+BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
   auto sizes = tensor.sizes();
   return std::vector<int32_t>(sizes.begin(), sizes.end());
 }
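Note on why these definitions can take const even though running a method mutates state inside the ExecuTorch Module: const-ness is shallow for pointer-like members. Inside a const member function, module_ itself cannot be reseated, but the Module it owns stays mutable. A minimal standalone sketch of that C++ rule (illustrative code, not part of the PR):

#include <memory>

struct Holder {
  std::unique_ptr<int> p = std::make_unique<int>(0);

  void touch() const {
    *p = 42;      // OK: const applies to the member p, not to the pointee
    // p.reset(); // error: reset() is non-const and would modify p itself
  }
};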
BaseModel.h

@@ -21,18 +21,20 @@ class BaseModel {
             std::shared_ptr<react::CallInvoker> callInvoker);
   std::size_t getMemoryLowerBound() const noexcept;
   void unload() noexcept;
-  std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
+  std::vector<int32_t> getInputShape(std::string method_name,
+                                     int32_t index) const;
   std::vector<std::vector<int32_t>>
-  getAllInputShapes(std::string methodName = "forward");
+  getAllInputShapes(std::string methodName = "forward") const;
   std::vector<JSTensorViewOut>
-  forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
+  forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const;
   Result<std::vector<EValue>> forward(const EValue &input_value) const;
   Result<std::vector<EValue>>
   forward(const std::vector<EValue> &input_value) const;
-  Result<std::vector<EValue>> execute(const std::string &methodName,
-                                      const std::vector<EValue> &input_value);
+  Result<std::vector<EValue>>
+  execute(const std::string &methodName,
+          const std::vector<EValue> &input_value) const;
   Result<executorch::runtime::MethodMeta>
-  getMethodMeta(const std::string &methodName);
+  getMethodMeta(const std::string &methodName) const;
 
 protected:
   // If possible, models should not use the JS runtime to keep JSI internals
@@ -44,7 +46,8 @@ class BaseModel {
 
 private:
   std::size_t memorySizeLowerBound{0};
-  std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
+  std::vector<int32_t>
+  getTensorShape(const executorch::aten::Tensor &tensor) const;
 };
 } // namespace models
 
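These header changes matter because consumers hold BaseModel through pointers to const (ASR.h below declares const models::BaseModel * members), and only const-qualified members are callable through such a pointer. A hypothetical caller that compiles only after this change (include path and function name are illustrative):

#include "rnexecutorch/models/BaseModel.h" // assumed path

void logModelInfo(const rnexecutorch::models::BaseModel *model) {
  // Both members gained const in this PR; through a pointer-to-const,
  // neither call would compile before it.
  const auto shapes = model->getAllInputShapes("forward");
  const auto meta = model->getMethodMeta("forward");
}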
ASR.cpp

@@ -4,7 +4,6 @@
 #include "ASR.h"
 #include "executorch/extension/tensor/tensor_ptr.h"
 #include "rnexecutorch/data_processing/Numerical.h"
-#include "rnexecutorch/data_processing/dsp.h"
 #include "rnexecutorch/data_processing/gzip.h"
 
 namespace rnexecutorch::models::speech_to_text::asr {
@@ -37,8 +36,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
   return seq;
 }
 
-GenerationResult ASR::generate(std::span<const float> waveform,
-                               float temperature,
+GenerationResult ASR::generate(std::span<float> waveform, float temperature,
                                const DecodingOptions &options) const {
   std::vector<float> encoderOutput = this->encode(waveform);
 
@@ -94,7 +92,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
 }
 
 std::vector<Segment>
-ASR::generateWithFallback(std::span<const float> waveform,
+ASR::generateWithFallback(std::span<float> waveform,
                           const DecodingOptions &options) const {
   std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
   std::vector<int32_t> bestTokens;
@@ -209,7 +207,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
   return wordObjs;
 }
 
-std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
+std::vector<Segment> ASR::transcribe(std::span<float> waveform,
                                      const DecodingOptions &options) const {
   int32_t seek = 0;
   std::vector<Segment> results;
@@ -218,7 +216,7 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
     int32_t start = seek * ASR::kSamplingRate;
     const auto end = std::min<int32_t>(
         (seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
-    std::span<const float> chunk = waveform.subspan(start, end - start);
+    auto chunk = waveform.subspan(start, end - start);
 
     if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
       break;
@@ -246,19 +244,12 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
   return results;
 }
 
-std::vector<float> ASR::encode(std::span<const float> waveform) const {
-  constexpr int32_t fftWindowSize = 512;
-  constexpr int32_t stftHopLength = 160;
-  constexpr int32_t innerDim = 256;
-
-  std::vector<float> preprocessedData =
-      dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
-  const auto numFrames =
-      static_cast<int32_t>(preprocessedData.size()) / innerDim;
-  std::vector<int32_t> inputShape = {numFrames, innerDim};
+std::vector<float> ASR::encode(std::span<float> waveform) const {
+  auto inputShape = {static_cast<int32_t>(waveform.size())};
 
   const auto modelInputTensor = executorch::extension::make_tensor_ptr(
-      std::move(inputShape), std::move(preprocessedData));
+      std::move(inputShape), waveform.data(),
+      executorch::runtime::etensor::ScalarType::Float);
   const auto encoderResult = this->encoder->forward(modelInputTensor);
 
   if (!encoderResult.ok()) {
@@ -268,7 +259,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
   }
 
   const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
-  const int32_t outputNumel = decoderOutputTensor.numel();
+  const auto outputNumel = decoderOutputTensor.numel();
 
   const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
   return {dataPtr, dataPtr + outputNumel};
@@ -277,8 +268,10 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
 std::vector<float> ASR::decode(std::span<int32_t> tokens,
                                std::span<float> encoderOutput) const {
   std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
+  auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());
 
   auto tokenTensor = executorch::extension::make_tensor_ptr(
-      std::move(tokenShape), tokens.data(), ScalarType::Int);
+      tokenShape, tokensLong.data(), ScalarType::Long);
 
   const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
   std::vector<int32_t> encShape = {1, ASR::kNumFrames,
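Both make_tensor_ptr call sites above now use the raw-pointer form (sizes, data pointer, explicit ScalarType), which wraps a caller-owned buffer rather than copying it. That is why the waveform spans drop const (the overload takes a mutable data pointer) and why tokensLong is kept as a named local that outlasts the decoder call. A minimal sketch of the same lifetime contract; the helper and its names are illustrative, and the overload shape is assumed from its use in this diff:

#include <cstdint>
#include <span>
#include <vector>

#include "executorch/extension/tensor/tensor_ptr.h"

using executorch::runtime::etensor::ScalarType;

// Widens int32 token ids into caller-provided int64 storage and wraps that
// storage in a tensor. The tensor only aliases storage.data(), so `storage`
// must stay alive until every forward() call using the tensor has returned.
auto makeTokenTensor(std::span<const int32_t> tokens,
                     std::vector<int64_t> &storage) {
  storage.assign(tokens.begin(), tokens.end());
  std::vector<int32_t> shape = {1, static_cast<int32_t>(storage.size())};
  return executorch::extension::make_tensor_ptr(shape, storage.data(),
                                                ScalarType::Long);
}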
ASR.h

@@ -14,9 +14,9 @@ class ASR {
      const models::BaseModel *decoder,
      const TokenizerModule *tokenizer);
   std::vector<types::Segment>
-  transcribe(std::span<const float> waveform,
+  transcribe(std::span<float> waveform,
              const types::DecodingOptions &options) const;
-  std::vector<float> encode(std::span<const float> waveform) const;
+  std::vector<float> encode(std::span<float> waveform) const;
   std::vector<float> decode(std::span<int32_t> tokens,
                             std::span<float> encoderOutput) const;
 
@@ -44,11 +44,10 @@ class ASR {
 
   std::vector<int32_t>
   getInitialSequence(const types::DecodingOptions &options) const;
-  types::GenerationResult generate(std::span<const float> waveform,
-                                   float temperature,
+  types::GenerationResult generate(std::span<float> waveform, float temperature,
                                    const types::DecodingOptions &options) const;
   std::vector<types::Segment>
-  generateWithFallback(std::span<const float> waveform,
+  generateWithFallback(std::span<float> waveform,
                        const types::DecodingOptions &options) const;
   std::vector<types::Segment>
   calculateWordLevelTimestamps(std::span<const int32_t> tokens,
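A caller-facing consequence of the span changes in this header: std::span<float> cannot view const data, so transcribe now needs a mutable sample buffer that stays alive for the whole call (the encoder tensor aliases it rather than copying). A hypothetical usage sketch (the wrapper function and its names are illustrative):

#include <vector>

#include "ASR.h" // assumed include path

using namespace rnexecutorch::models::speech_to_text;

// The buffer is taken by non-const reference because transcribe() requires
// std::span<float> over mutable samples.
std::vector<types::Segment> run(const asr::ASR &model,
                                std::vector<float> &samples,
                                const types::DecodingOptions &options) {
  // OK: a non-const vector<float> converts to std::span<float>.
  // A const std::vector<float> would no longer compile here.
  return model.transcribe(samples, options);
}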