diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt
index 9a2ffd573..2bacdccea 100644
--- a/.cspell-wordlist.txt
+++ b/.cspell-wordlist.txt
@@ -80,3 +80,6 @@ setpriority
errno
ifdef
elif
+FSMN
+fsmn
+subarray
\ No newline at end of file
diff --git a/docs/docs/02-hooks/01-natural-language-processing/useVAD.md b/docs/docs/02-hooks/01-natural-language-processing/useVAD.md
new file mode 100644
index 000000000..c051d8bce
--- /dev/null
+++ b/docs/docs/02-hooks/01-natural-language-processing/useVAD.md
@@ -0,0 +1,191 @@
+---
+title: useVAD
+---
+
+Voice Activity Detection (VAD) is the task of analyzing an audio signal to identify time segments containing human speech, separating them from non-speech sections like silence and background noise.
+
+:::caution
+It is recommended to use models provided by us, which are available at our [Hugging Face repository](https://huggingface.co/software-mansion/react-native-executorch-fsmn-vad). You can also use [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) shipped with our library.
+:::
+
+## Reference
+
+You can obtain the waveform from audio in any way that suits you; in the snippet below we use the `react-native-audio-api` library to process an `.mp3` file.
+
+```typescript
+import { useVAD, FSMN_VAD } from 'react-native-executorch';
+import { AudioContext } from 'react-native-audio-api';
+import * as FileSystem from 'expo-file-system';
+
+const model = useVAD({
+ model: FSMN_VAD,
+});
+
+const { uri } = await FileSystem.downloadAsync(
+ 'https://some-audio-url.com/file.mp3',
+ FileSystem.cacheDirectory + 'audio_file'
+);
+
+const audioContext = new AudioContext({ sampleRate: 16000 });
+const decodedAudioData = await audioContext.decodeAudioDataSource(uri);
+const audioBuffer = decodedAudioData.getChannelData(0);
+
+try {
+ const speechSegments = await model.forward(audioBuffer);
+ console.log(speechSegments);
+} catch (error) {
+  console.error('Error while running the VAD model:', error);
+}
+```
+
+### Arguments
+
+**`model`** - Object containing the model source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+
+**`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data on first use) after running the hook.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
+
+### Returns
+
+| Field | Type | Description |
+| ------------------ | -------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- |
+| `forward`          | `(waveform: Float32Array) => Promise<Segment[]>`   | Executes the model's forward pass, where the input array should be a waveform sampled at 16 kHz. Returns a promise resolving to an array of `Segment` objects.    |
+| `error`            | `string \| null`                                    | Contains the error message if the model failed to load.                                                                                                           |
+| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
+| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
+| `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1. |
+
+
+Type definitions
+
+```typescript
+interface Segment {
+ start: number;
+ end: number;
+}
+```
+
+
+## Running the model
+
+Before running the model's `forward` method, make sure to extract the audio waveform you want to process. You'll need to handle this step yourself, ensuring the audio is sampled at 16 kHz. Once you have the waveform, pass it as an argument to the forward method. The method returns a promise that resolves to the array of detected speech segments.
+
+:::info
+Timestamps in the returned speech segments correspond to indices of the input array (waveform).
+:::
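+
+Since the audio is sampled at 16 kHz, you can convert those indices to seconds by dividing by the sample rate. A minimal sketch, assuming the `speechSegments` array from the reference snippet above:
+
+```typescript
+const SAMPLE_RATE = 16000; // the model expects 16 kHz audio
+
+// Convert sample indices returned by the model into seconds.
+const segmentsInSeconds = speechSegments.map(({ start, end }) => ({
+  start: start / SAMPLE_RATE,
+  end: end / SAMPLE_RATE,
+}));
+
+segmentsInSeconds.forEach(({ start, end }, i) =>
+  console.log(`Segment ${i}: ${start.toFixed(2)}s - ${end.toFixed(2)}s`)
+);
+```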
+
+## Example
+
+```tsx
+import React from 'react';
+import { Button, Text, SafeAreaView } from 'react-native';
+import { useVAD, FSMN_VAD } from 'react-native-executorch';
+import { AudioContext } from 'react-native-audio-api';
+import * as FileSystem from 'expo-file-system';
+
+export default function App() {
+ const model = useVAD({
+ model: FSMN_VAD,
+ });
+
+ const audioURL = 'https://some-audio-url.com/file.mp3';
+
+ const handleAudio = async () => {
+    if (!model.isReady) {
+ console.error('VAD model is not loaded yet.');
+ return;
+ }
+
+ console.log('Processing URL:', audioURL);
+
+ try {
+ const { uri } = await FileSystem.downloadAsync(
+ audioURL,
+ FileSystem.cacheDirectory + 'vad_example.tmp'
+ );
+
+ const audioContext = new AudioContext({ sampleRate: 16000 });
+ const originalDecodedBuffer =
+ await audioContext.decodeAudioDataSource(uri);
+ const originalChannelData = originalDecodedBuffer.getChannelData(0);
+
+ const segments = await model.forward(originalChannelData);
+ if (segments.length === 0) {
+ console.log('No speech segments were found.');
+ return;
+ }
+ console.log(`Found ${segments.length} speech segments.`);
+
+ const totalLength = segments.reduce(
+ (sum, seg) => sum + (seg.end - seg.start),
+ 0
+ );
+ const newAudioBuffer = audioContext.createBuffer(
+ 1, // Mono
+ totalLength,
+ originalDecodedBuffer.sampleRate
+ );
+ const newChannelData = newAudioBuffer.getChannelData(0);
+
+ let offset = 0;
+ for (const segment of segments) {
+ const slice = originalChannelData.subarray(segment.start, segment.end);
+ newChannelData.set(slice, offset);
+ offset += slice.length;
+ }
+
+ // Play the processed audio
+ const source = audioContext.createBufferSource();
+ source.buffer = newAudioBuffer;
+ source.connect(audioContext.destination);
+ source.start();
+ } catch (error) {
+ console.error('Error processing audio data:', error);
+ }
+ };
+
+  return (
+    <SafeAreaView>
+      <Text>
+        Press the button to process and play speech from a sample file.
+      </Text>
+      <Button title="Process audio" onPress={handleAudio} />
+    </SafeAreaView>
+  );
+}
+```
+
+## Supported models
+
+- [fsmn-vad](https://huggingface.co/funasr/fsmn-vad)
+
+## Benchmarks
+
+### Model size
+
+| Model | XNNPACK [MB] |
+| -------- | :----------: |
+| FSMN_VAD | 1.83 |
+
+### Memory usage
+
+| Model | Android (XNNPACK) [MB] | iOS (XNNPACK) [MB] |
+| -------- | :--------------------: | :----------------: |
+| FSMN_VAD | 97                     | 45.9               |
+
+### Inference time
+
+
+
+:::warning warning
+Times presented in the tables are measured as consecutive runs of the model. Initial run times may be up to 2x longer due to model loading and initialization.
+:::
+
+Inference times were measured on a 60 s audio clip, which can be found [here](https://models.silero.ai/vad_models/en.wav).
+
+| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro Max (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] |
+| -------- | :--------------------------: | :------------------------------: | :------------------------: | :-----------------------: |
+| FSMN_VAD | 151 | 171 | 180 | 109 |
diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/VADModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/VADModule.md
new file mode 100644
index 000000000..7f06ab95f
--- /dev/null
+++ b/docs/docs/03-typescript-api/01-natural-language-processing/VADModule.md
@@ -0,0 +1,64 @@
+---
+title: VADModule
+---
+
+TypeScript API implementation of the [useVAD](../../02-hooks/01-natural-language-processing/useVAD.md) hook.
+
+## Reference
+
+```typescript
+import { VADModule, FSMN_VAD } from 'react-native-executorch';
+
+const model = new VADModule();
+await model.load(FSMN_VAD, (progress) => {
+ console.log(progress);
+});
+
+await model.forward(waveform);
+```
+
+### Methods
+
+| Method | Type | Description |
+| --------- | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `load`    | `(model: { modelSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise<void>` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply the optional `onDownloadProgressCallback` callback. |
+| `forward` | `(waveform: Float32Array): Promise<Segment[]>`                                                                      | Executes the model's forward pass, where the input array should be a waveform sampled at 16 kHz. Returns a promise resolving to an array of `Segment` objects.                               |
+| `delete`  | `(): void`                                                                                                          | Releases the memory held by the module. Calling `forward` afterwards is invalid.                                                                                                             |
+
+
+Type definitions
+
+```typescript
+type ResourceSource = string | number | object;
+```
+
+```typescript
+interface Segment {
+ start: number;
+ end: number;
+}
+```
+
+
+
+## Loading the model
+
+To load the model, create a new instance of the module and call its `load` method. It accepts the following arguments:
+
+**`model`** - Object containing the model source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
+
+This method returns a promise that resolves once the model has loaded, or rejects with an error if loading fails.
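+
+For example, you can track the download progress with the callback and surface loading errors with a `try`/`catch`; the percentage formatting below is only an illustration:
+
+```typescript
+import { VADModule, FSMN_VAD } from 'react-native-executorch';
+
+const model = new VADModule();
+
+try {
+  await model.load(FSMN_VAD, (progress) => {
+    console.log(`Downloading model: ${Math.round(progress * 100)}%`);
+  });
+  console.log('VAD model is ready.');
+} catch (error) {
+  console.error('Failed to load the VAD model:', error);
+}
+```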
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
+
+## Running the model
+
+To run the model, you can use the `forward` method on the module object. Before running the model's `forward` method, make sure to extract the audio waveform you want to process. You'll need to handle this step yourself, ensuring the audio is sampled at 16 kHz. Once you have the waveform, pass it as an argument to the forward method. The method returns a promise that resolves to the array of detected speech segments.
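+
+A minimal sketch of this flow, assuming `waveform` is a `Float32Array` sampled at 16 kHz (obtained, for instance, with `react-native-audio-api` as shown in the [useVAD](../../02-hooks/01-natural-language-processing/useVAD.md) docs):
+
+```typescript
+// `waveform` is a Float32Array with audio sampled at 16 kHz.
+const segments = await model.forward(waveform);
+
+for (const { start, end } of segments) {
+  // start and end are indices into the input waveform.
+  console.log(`Speech from sample ${start} to sample ${end}`);
+}
+```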
+
+## Managing memory
+
+The module is a regular JavaScript object, so its lifespan is managed by the garbage collector. In most cases this is enough and you should not need to free the module's memory yourself, but if you want to release the memory occupied by the module before the garbage collector steps in, call `delete()` on a module object you no longer use. Note that you cannot use `forward` after `delete` unless you load the module again.
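+
+For example, once you are done with a module instance (continuing from the loading sketch above, and assuming `waveform` is a 16 kHz `Float32Array`):
+
+```typescript
+const segments = await model.forward(waveform);
+console.log(segments);
+
+// Free the native memory once the module is no longer needed.
+model.delete();
+
+// Any further `forward` call will now throw until `load` is called again.
+```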
diff --git a/docs/docs/04-benchmarks/inference-time.md b/docs/docs/04-benchmarks/inference-time.md
index dd0f1275a..fa12ade94 100644
--- a/docs/docs/04-benchmarks/inference-time.md
+++ b/docs/docs/04-benchmarks/inference-time.md
@@ -62,7 +62,7 @@ Times presented in the tables are measured as consecutive runs of the model. Ini
❌ - Insufficient RAM.
-### Streaming mode
+## Streaming mode
Notice than for `Whisper` model which has to take as an input 30 seconds audio chunks (for shorter audio it is automatically padded with silence to 30 seconds) `fast` mode has the lowest latency (time from starting transcription to first token returned, caused by streaming algorithm), but the slowest speed. If you believe that this might be a problem for you, prefer `balanced` mode instead.
@@ -119,3 +119,13 @@ Average time for generating one image of size 256×256 in 10 inference steps.
| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro Max (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) | Samsung Galaxy S24 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] |
| --------------------- | :--------------------------: | :------------------------------: | :-------------------: | :-------------------------------: | :-----------------------: |
| BK_SDM_TINY_VPRED_256 | 19100 | 25000 | ❌ | ❌ | 23100 |
+
+## Voice Activity Detection (VAD)
+
+Average time for processing a 60 s audio clip.
+
+
+
+| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro Max (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] |
+| -------- | :--------------------------: | :------------------------------: | :------------------------: | :-----------------------: |
+| FSMN_VAD | 151 | 171 | 180 | 109 |
diff --git a/docs/docs/04-benchmarks/memory-usage.md b/docs/docs/04-benchmarks/memory-usage.md
index e34c8a7ca..e250b0c90 100644
--- a/docs/docs/04-benchmarks/memory-usage.md
+++ b/docs/docs/04-benchmarks/memory-usage.md
@@ -75,3 +75,9 @@ title: Memory Usage
| --------------------- | ---------------------- | ------------------ |
| BK_SDM_TINY_VPRED_256 | 2900 | 2800 |
| BK_SDM_TINY_VPRED | 6700 | 6560 |
+
+## Voice Activity Detection (VAD)
+
+| Model | Android (XNNPACK) [MB] | iOS (XNNPACK) [MB] |
+| -------- | :--------------------: | :----------------: |
+| FSMN_VAD | 97                     | 45.9               |
diff --git a/docs/docs/04-benchmarks/model-size.md b/docs/docs/04-benchmarks/model-size.md
index 5cf87f6fa..2a648ac53 100644
--- a/docs/docs/04-benchmarks/model-size.md
+++ b/docs/docs/04-benchmarks/model-size.md
@@ -88,3 +88,9 @@ title: Model Size
| Model | Text encoder (XNNPACK) [MB] | UNet (XNNPACK) [MB] | VAE decoder (XNNPACK) [MB] |
| ----------------- | --------------------------- | ------------------- | -------------------------- |
| BK_SDM_TINY_VPRED | 492 | 1290 | 198 |
+
+## Voice Activity Detection (VAD)
+
+| Model | XNNPACK [MB] |
+| -------- | :----------: |
+| FSMN_VAD | 1.83 |
diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp
index 6d5e48902..ad1126a83 100644
--- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp
@@ -6,13 +6,14 @@
#include
#include
#include
-#include
#include
#include
#include
#include
#include
+#include
#include
+#include
#include
#include
@@ -51,8 +52,9 @@ void RnExecutorchInstaller::injectJSIBindings(
jsiRuntime->global().setProperty(
*jsiRuntime, "loadObjectDetection",
- RnExecutorchInstaller::loadModel(
- jsiRuntime, jsCallInvoker, "loadObjectDetection"));
+ RnExecutorchInstaller::loadModel<
+ models::object_detection::ObjectDetection>(jsiRuntime, jsCallInvoker,
+ "loadObjectDetection"));
jsiRuntime->global().setProperty(
*jsiRuntime, "loadExecutorchModule",
@@ -93,6 +95,12 @@ void RnExecutorchInstaller::injectJSIBindings(
RnExecutorchInstaller::loadModel(
jsiRuntime, jsCallInvoker, "loadSpeechToText"));
+ jsiRuntime->global().setProperty(
+ *jsiRuntime, "loadVAD",
+ RnExecutorchInstaller::loadModel<
+ models::voice_activity_detection::VoiceActivityDetection>(
+ jsiRuntime, jsCallInvoker, "loadVAD"));
+
threads::utils::unsafeSetupThreadPool();
threads::GlobalThreadPool::initialize();
}
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
index 6bb43bd51..04b160b54 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
@@ -17,6 +17,7 @@
#include
#include
#include
+#include
namespace rnexecutorch::jsi_conversion {
@@ -66,7 +67,8 @@ inline JSTensorViewIn getValue(const jsi::Value &val,
tensorView.sizes.reserve(numShapeDims);
for (size_t i = 0; i < numShapeDims; ++i) {
- int32_t dim = getValue(shapeArray.getValueAtIndex(runtime, i), runtime);
+ int32_t dim =
+ getValue(shapeArray.getValueAtIndex(runtime, i), runtime);
tensorView.sizes.push_back(dim);
}
@@ -173,23 +175,24 @@ inline std::vector getArrayAsVector(const jsi::Value &val,
return result;
}
-
// Template specializations for std::vector types
template <>
-inline std::vector getValue>(const jsi::Value &val,
- jsi::Runtime &runtime) {
+inline std::vector
+getValue>(const jsi::Value &val,
+ jsi::Runtime &runtime) {
return getArrayAsVector(val, runtime);
}
template <>
-inline std::vector getValue>(const jsi::Value &val,
- jsi::Runtime &runtime) {
+inline std::vector
+getValue>(const jsi::Value &val,
+ jsi::Runtime &runtime) {
return getArrayAsVector(val, runtime);
}
template <>
-inline std::vector getValue>(const jsi::Value &val,
- jsi::Runtime &runtime) {
+inline std::vector
+getValue>(const jsi::Value &val, jsi::Runtime &runtime) {
return getArrayAsVector(val, runtime);
}
@@ -388,4 +391,19 @@ getJsiValue(const std::vector &detections,
return jsiDetections;
}
+inline jsi::Value
+getJsiValue(const std::vector
+ &speechSegments,
+ jsi::Runtime &runtime) {
+ auto jsiSegments = jsi::Array(runtime, speechSegments.size());
+ for (size_t i = 0; i < speechSegments.size(); i++) {
+ const auto &[start, end] = speechSegments[i];
+ auto jsiSegmentObject = jsi::Object(runtime);
+ jsiSegmentObject.setProperty(runtime, "start", static_cast(start));
+ jsiSegmentObject.setProperty(runtime, "end", static_cast(end));
+ jsiSegments.setValueAtIndex(runtime, i, jsiSegmentObject);
+ }
+ return jsiSegments;
+}
+
} // namespace rnexecutorch::jsi_conversion
diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h b/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h
index a76fb0270..392e0b155 100644
--- a/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h
+++ b/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h
@@ -23,9 +23,7 @@ class OwningArrayBuffer : public jsi::MutableBuffer {
/**
* @param size Size of the buffer in bytes.
*/
- OwningArrayBuffer(size_t size) : size_(size) {
- data_ = new uint8_t[size_];
- }
+ OwningArrayBuffer(size_t size) : size_(size) { data_ = new uint8_t[size_]; }
/**
* @param data Pointer to the source data.
* @param size Size of the data in bytes.
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h
new file mode 100644
index 000000000..eaf34796a
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include
+#include
+#include
+namespace rnexecutorch::models::voice_activity_detection::constants {
+
+inline constexpr uint32_t kSampleRate = 16000;
+inline constexpr auto kMstoSecond = 0.001f;
+inline constexpr uint32_t kWindowSizeMs = 25;
+inline constexpr uint32_t kHopLengthMs = 10;
+inline constexpr auto kWindowSize =
+ static_cast(kMstoSecond * kWindowSizeMs * kSampleRate); // 400
+inline constexpr auto kHopLength =
+ static_cast(kMstoSecond * kHopLengthMs * kSampleRate); // 160
+inline constexpr auto kPreemphasisCoeff = 0.97f;
+inline constexpr auto kLeftPadding = (kWindowSize - 1) / 2;
+inline constexpr auto kRightPadding = kWindowSize / 2;
+inline constexpr auto kPaddedWindowSize = std::bit_ceil(kWindowSize); // 512
+inline constexpr size_t kModelInputMin = 100;
+inline constexpr size_t kModelInputMax = 1000;
+inline constexpr auto kSpeechThreshold = 0.6f;
+inline constexpr size_t kMinSpeechDuration = 25; // 250 ms
+inline constexpr size_t kMinSilenceDuration = 10; // 100 ms
+inline constexpr size_t kSpeechPad = 3; // 30 ms
+
+} // namespace rnexecutorch::models::voice_activity_detection::constants
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Types.h
new file mode 100644
index 000000000..51794d6bf
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Types.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include
+
+namespace rnexecutorch::models::voice_activity_detection::types {
+
+struct Segment {
+ size_t start;
+ size_t end;
+};
+
+} // namespace rnexecutorch::models::voice_activity_detection::types
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp
new file mode 100644
index 000000000..881422b4b
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp
@@ -0,0 +1,15 @@
+#include "Utils.h"
+
+namespace rnexecutorch::models::voice_activity_detection::utils {
+size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
+ size_t numClass, size_t size,
+ std::vector &resultVector,
+ size_t startIdx) {
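+  // The model output is laid out row-major as [size x numClass]; class 0 holds
+  // the non-speech probability, hence the stride of numClass.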
+ auto rawData = tensor.const_data_ptr();
+ for (size_t i = 0; i < size; i++) {
+ resultVector[startIdx + i] = rawData[numClass * i];
+ }
+ return startIdx + size;
+}
+
+} // namespace rnexecutorch::models::voice_activity_detection::utils
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h
new file mode 100644
index 000000000..b01670088
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include
+#include
+#include
+
+namespace rnexecutorch::models::voice_activity_detection::utils {
+size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor,
+ size_t numClass, size_t size,
+ std::vector &resultVector,
+ size_t startIdx);
+
+} // namespace rnexecutorch::models::voice_activity_detection::utils
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp
new file mode 100644
index 000000000..d07dbfb3c
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp
@@ -0,0 +1,161 @@
+#include "VoiceActivityDetection.h"
+#include "rnexecutorch/data_processing/dsp.h"
+#include "rnexecutorch/models/voice_activity_detection/Utils.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace rnexecutorch::models::voice_activity_detection {
+using namespace constants;
+namespace ranges = std::ranges;
+using executorch::aten::Tensor;
+using executorch::extension::TensorPtr;
+
+VoiceActivityDetection::VoiceActivityDetection(
+ const std::string &modelSource,
+ std::shared_ptr callInvoker)
+ : BaseModel(modelSource, callInvoker) {}
+
+std::vector>
+VoiceActivityDetection::preprocess(std::span waveform) const {
+ auto kHammingWindowArray = dsp::hannWindow(kWindowSize);
+
+ const size_t numFrames = (waveform.size() - kWindowSize) / kHopLength;
+
+ std::vector> frameBuffer(
+ numFrames, std::array{});
+
+ constexpr size_t totalPadding = kPaddedWindowSize - kWindowSize;
+ constexpr size_t leftPadding = totalPadding / 2;
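+  // Each kWindowSize-sample frame is centred inside a kPaddedWindowSize buffer,
+  // i.e. zero-padded symmetrically on both sides.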
+ for (size_t i = 0; i < numFrames; i++) {
+
+ auto windowView = waveform.subspan(i * kHopLength, kWindowSize);
+ ranges::copy(windowView, frameBuffer[i].begin() + leftPadding);
+ auto frameView =
+ std::span{frameBuffer[i].data() + leftPadding, kWindowSize};
+ const float sum = std::reduce(frameView.begin(), frameView.end(), 0.0f);
+ const float mean = sum / kWindowSize;
+ ranges::transform(frameView, frameView.begin(),
+ [mean](float value) { return value - mean; });
+
+ // apply pre-emphasis filter
+ for (auto j = frameView.size() - 1; j > 0; --j) {
+ frameView[j] -= kPreemphasisCoeff * frameView[j - 1];
+ }
+    // apply the window (Hann, see dsp::hannWindow) to reduce spectral leakage
+ ranges::transform(frameView, kHammingWindowArray, frameView.begin(),
+ std::multiplies{});
+ }
+ return frameBuffer;
+}
+
+std::vector
+VoiceActivityDetection::generate(std::span waveform) const {
+
+ auto windowedInput = preprocess(waveform);
+ auto [chunksNumber, remainder] = std::div(
+ static_cast(windowedInput.size()), static_cast(kModelInputMax));
+ std::vector scores(windowedInput.size());
+ auto lastChunkSize = remainder;
+ if (remainder < kModelInputMin) {
+ auto paddingSize = kModelInputMin - remainder;
+ lastChunkSize = kModelInputMin;
+ windowedInput.insert(windowedInput.end(), paddingSize,
+ std::array{});
+ }
+ TensorPtr inputTensor;
+ size_t startIdx = 0;
+
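+  // Run full chunks of kModelInputMax frames through the model; the remaining
+  // frames (padded up to kModelInputMin if needed) are processed after the loop.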
+ for (size_t i = 0; i < chunksNumber; i++) {
+ std::span> chunk(
+ windowedInput.data() + kModelInputMax * i, kModelInputMax);
+ inputTensor = executorch::extension::from_blob(
+ chunk.data(), {kModelInputMax, kPaddedWindowSize},
+ executorch::aten::ScalarType::Float);
+ auto forwardResult = BaseModel::forward(inputTensor);
+ if (!forwardResult.ok()) {
+ throw std::runtime_error(
+ "Failed to forward, error: " +
+ std::to_string(static_cast(forwardResult.error())));
+ }
+ auto tensor = forwardResult->at(0).toTensor();
+ startIdx = utils::getNonSpeechClassProbabilites(
+ tensor, tensor.size(2), tensor.size(1), scores, startIdx);
+ }
+
+ std::span> lastChunk(
+ windowedInput.data() + kModelInputMax * chunksNumber, lastChunkSize);
+ inputTensor = executorch::extension::from_blob(
+ lastChunk.data(), {lastChunkSize, kPaddedWindowSize},
+ executorch::aten::ScalarType::Float);
+ auto forwardResult = BaseModel::forward(inputTensor);
+ if (!forwardResult.ok()) {
+ throw std::runtime_error(
+ "Failed to forward, error: " +
+ std::to_string(static_cast(forwardResult.error())));
+ }
+ auto tensor = forwardResult->at(0).toTensor();
+ startIdx = utils::getNonSpeechClassProbabilites(tensor, tensor.size(2),
+ remainder, scores, startIdx);
+ return postprocess(scores, kSpeechThreshold);
+}
+
+std::vector
+VoiceActivityDetection::postprocess(const std::vector &scores,
+ float threshold) const {
+ bool triggered = false;
+ std::vector speechSegments{};
+ ssize_t startSegment = -1;
+ ssize_t endSegment = -1;
+ ssize_t potentialStart = -1;
+ ssize_t potentialEnd = -1;
+ float score;
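+  // Hysteresis: speech starts only after kMinSpeechDuration consecutive frames
+  // above the threshold and ends only after kMinSilenceDuration frames below it.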
+ for (size_t i = 0; i < scores.size(); i++) {
+ score = 1 - scores[i];
+ if (!triggered) {
+ if (score >= threshold) {
+ if (potentialStart == -1) {
+ potentialStart = i;
+ } else if (i - potentialStart >= kMinSpeechDuration) {
+ triggered = true;
+ startSegment = potentialStart;
+ potentialStart = -1;
+ }
+ } else { // score < threshold
+ potentialStart = -1;
+ }
+ } else { // triggered
+ if (score < threshold) {
+ if (potentialEnd == -1) {
+ potentialEnd = i;
+ } else if (i - potentialEnd >= kMinSilenceDuration) {
+ triggered = false;
+ endSegment = potentialEnd;
+ speechSegments.emplace_back(startSegment, endSegment);
+ potentialEnd = -1;
+ }
+ } else {
+ potentialEnd = -1;
+ }
+ }
+ }
+ if (triggered) {
+ endSegment = scores.size();
+ speechSegments.emplace_back(startSegment, endSegment);
+ }
+
+ for (auto &[start, end] : speechSegments) {
+    // start - kSpeechPad could underflow (start is unsigned), which is why we
+    // use the ternary operator instead of std::max.
+ start = (start > kSpeechPad ? start - kSpeechPad : 0) * kHopLength;
+ end = std::min(end + kSpeechPad, scores.size()) * kHopLength;
+ }
+
+ return speechSegments;
+}
+
+} // namespace rnexecutorch::models::voice_activity_detection
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h
new file mode 100644
index 000000000..e69288930
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
+#include "rnexecutorch/models/BaseModel.h"
+#include "rnexecutorch/models/voice_activity_detection/Constants.h"
+#include "rnexecutorch/models/voice_activity_detection/Types.h"
+
+namespace rnexecutorch {
+namespace models::voice_activity_detection {
+using executorch::extension::TensorPtr;
+using executorch::runtime::EValue;
+class VoiceActivityDetection : public BaseModel {
+public:
+ VoiceActivityDetection(const std::string &modelSource,
+ std::shared_ptr callInvoker);
+ [[nodiscard("Registered non-void function")]] std::vector
+ generate(std::span waveform) const;
+
+private:
+ std::vector>
+ preprocess(std::span waveform) const;
+ std::vector postprocess(const std::vector &scores,
+ float threshold) const;
+};
+} // namespace models::voice_activity_detection
+
+REGISTER_CONSTRUCTOR(models::voice_activity_detection::VoiceActivityDetection,
+ std::string, std::shared_ptr);
+} // namespace rnexecutorch
\ No newline at end of file
diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts
index 34834733e..57381cf15 100644
--- a/packages/react-native-executorch/src/constants/modelUrls.ts
+++ b/packages/react-native-executorch/src/constants/modelUrls.ts
@@ -3,6 +3,7 @@ import { Platform } from 'react-native';
const URL_PREFIX =
'https://huggingface.co/software-mansion/react-native-executorch';
const VERSION_TAG = 'resolve/v0.5.0';
+const NEXT_VERSION_TAG = 'resolve/v0.6.0';
// LLMs
@@ -439,3 +440,10 @@ export const BK_SDM_TINY_VPRED_256 = {
unetSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/unet/model.256.pte`,
decoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/vae/model.256.pte`,
};
+
+// Voice Activity Detection
+const FSMN_VAD_MODEL = `${URL_PREFIX}-fsmn-vad/${NEXT_VERSION_TAG}/xnnpack/fsmn-vad_xnnpack.pte`;
+
+export const FSMN_VAD = {
+ modelSource: FSMN_VAD_MODEL,
+};
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts
new file mode 100644
index 000000000..2e3cb4235
--- /dev/null
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts
@@ -0,0 +1,15 @@
+import { ResourceSource } from '../../types/common';
+import { useModule } from '../useModule';
+import { VADModule } from '../../modules/natural_language_processing/VADModule';
+
+interface Props {
+ model: { modelSource: ResourceSource };
+ preventLoad?: boolean;
+}
+
+export const useVAD = ({ model, preventLoad = false }: Props) =>
+ useModule({
+ module: VADModule,
+ model,
+ preventLoad: preventLoad,
+ });
diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts
index 623022113..cddc6f595 100644
--- a/packages/react-native-executorch/src/index.ts
+++ b/packages/react-native-executorch/src/index.ts
@@ -9,6 +9,7 @@ declare global {
var loadExecutorchModule: (source: string) => any;
var loadTokenizerModule: (source: string) => any;
var loadImageEmbeddings: (source: string) => any;
+ var loadVAD: (source: string) => any;
var loadTextEmbeddings: (modelSource: string, tokenizerSource: string) => any;
var loadLLM: (modelSource: string, tokenizerSource: string) => any;
var loadTextToImage: (
@@ -52,6 +53,7 @@ if (
global.loadTokenizerModule == null ||
global.loadTextEmbeddings == null ||
global.loadImageEmbeddings == null ||
+ global.loadVAD == null ||
global.loadLLM == null ||
global.loadSpeechToText == null ||
global.loadOCR == null ||
@@ -79,6 +81,7 @@ export * from './hooks/natural_language_processing/useLLM';
export * from './hooks/natural_language_processing/useSpeechToText';
export * from './hooks/natural_language_processing/useTextEmbeddings';
export * from './hooks/natural_language_processing/useTokenizer';
+export * from './hooks/natural_language_processing/useVAD';
export * from './hooks/general/useExecutorchModule';
@@ -96,6 +99,7 @@ export * from './modules/natural_language_processing/LLMModule';
export * from './modules/natural_language_processing/SpeechToTextModule';
export * from './modules/natural_language_processing/TextEmbeddingsModule';
export * from './modules/natural_language_processing/TokenizerModule';
+export * from './modules/natural_language_processing/VADModule';
export * from './modules/general/ExecutorchModule';
@@ -108,6 +112,7 @@ export * from './types/objectDetection';
export * from './types/ocr';
export * from './types/imageSegmentation';
export * from './types/llm';
+export * from './types/vad';
export * from './types/common';
export {
SpeechToTextLanguage,
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts
new file mode 100644
index 000000000..9e784a0e4
--- /dev/null
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts
@@ -0,0 +1,27 @@
+import { ResourceFetcher } from '../../utils/ResourceFetcher';
+import { ResourceSource } from '../../types/common';
+import { Segment } from '../../types/vad';
+import { ETError, getError } from '../../Error';
+import { BaseModule } from '../BaseModule';
+
+export class VADModule extends BaseModule {
+ async load(
+ model: { modelSource: ResourceSource },
+ onDownloadProgressCallback: (progress: number) => void = () => {}
+  ): Promise<void> {
+ const paths = await ResourceFetcher.fetch(
+ onDownloadProgressCallback,
+ model.modelSource
+ );
+ if (paths === null || paths.length < 1) {
+ throw new Error('Download interrupted.');
+ }
+ this.nativeModule = global.loadVAD(paths[0] || '');
+ }
+
+  async forward(waveform: Float32Array): Promise<Segment[]> {
+ if (this.nativeModule == null)
+ throw new Error(getError(ETError.ModuleNotLoaded));
+ return await this.nativeModule.generate(waveform);
+ }
+}
diff --git a/packages/react-native-executorch/src/types/vad.ts b/packages/react-native-executorch/src/types/vad.ts
new file mode 100644
index 000000000..ca0dff920
--- /dev/null
+++ b/packages/react-native-executorch/src/types/vad.ts
@@ -0,0 +1,4 @@
+export interface Segment {
+ start: number;
+ end: number;
+}