diff --git a/.gitmodules b/.gitmodules index 73f832b1d6ce9a..02a8988a0e3523 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,7 +6,7 @@ url = ../../commaai/opendbc.git [submodule "cereal"] path = cereal - url = ../../commaai/cereal.git + url = ../../jakethesnake420/cereal.git [submodule "rednose_repo"] path = rednose_repo url = ../../commaai/rednose.git diff --git a/cereal b/cereal index 20b65eeb1f6c58..02e154efe6a8f4 160000 --- a/cereal +++ b/cereal @@ -1 +1 @@ -Subproject commit 20b65eeb1f6c580cdd7d63e53639f4fc48bc2f56 +Subproject commit 02e154efe6a8f4fdc329884145f2e09bf39f8909 diff --git a/common/params.cc b/common/params.cc index 386813efdd3514..dc93fc448e0cae 100644 --- a/common/params.cc +++ b/common/params.cc @@ -187,6 +187,7 @@ std::unordered_map keys = { {"RecordFrontLock", PERSISTENT}, // for the internal fleet {"ReplayControlsState", CLEAR_ON_MANAGER_START | CLEAR_ON_ONROAD_TRANSITION}, {"SnoozeUpdate", CLEAR_ON_MANAGER_START | CLEAR_ON_OFFROAD_TRANSITION}, + {"SpeechToTextAllowed", PERSISTENT}, {"SshEnabled", PERSISTENT}, {"TermsVersion", PERSISTENT}, {"Timezone", PERSISTENT}, @@ -205,6 +206,7 @@ std::unordered_map keys = { {"UpdaterLastFetchTime", PERSISTENT}, {"Version", PERSISTENT}, {"VisionRadarToggle", PERSISTENT}, + {"WakeWordDetected", CLEAR_ON_MANAGER_START}, {"WheeledBody", PERSISTENT}, }; diff --git a/poetry.lock b/poetry.lock index 41b677e5e0c7c7..2772957aaf3530 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3957,6 +3957,26 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rev_ai" +version = "2.19.4" +description = "Rev AI makes speech applications easy to build!" +optional = false +python-versions = "*" +files = [] +develop = false + +[package.dependencies] +requests = ">=2.21.0,<3.0.0" +six = ">=1.12.0,<2.0.0" +websocket-client = ">=0.56.0,<1.0.0" + +[package.source] +type = "git" +url = "https://github.com/jakethesnake420/revai-python-sdk.git" +reference = "patch-1" +resolved_reference = "386dd99c189000a5eb9c5056f028ce86488c352e" + [[package]] name = "ruff" version = "0.1.13" @@ -4720,19 +4740,17 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "websocket-client" -version = "1.7.0" +version = "0.59.0" description = "WebSocket client for Python with low level API options" optional = false -python-versions = ">=3.8" +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" files = [ - {file = "websocket-client-1.7.0.tar.gz", hash = "sha256:10e511ea3a8c744631d3bd77e61eb17ed09304c413ad42cf6ddfa4c7787e8fe6"}, - {file = "websocket_client-1.7.0-py3-none-any.whl", hash = "sha256:f4c3d22fec12a2461427a29957ff07d35098ee2d976d3ba244e688b8b4057588"}, + {file = "websocket-client-0.59.0.tar.gz", hash = "sha256:d376bd60eace9d437ab6d7ee16f4ab4e821c9dae591e1b783c58ebd8aaf80c5c"}, + {file = "websocket_client-0.59.0-py2.py3-none-any.whl", hash = "sha256:2e50d26ca593f70aba7b13a489435ef88b8fc3b5c5643c1ce8808ff9b40f0b32"}, ] -[package.extras] -docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] -optional = ["python-socks", "wsaccel"] -test = ["websockets"] +[package.dependencies] +six = "*" [[package]] name = "yapf" @@ -4871,4 +4889,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "~3.11" -content-hash = "1976ee7795d5ac5b257cac8dd83e408ae6e75c48d8c349e3e8e56519727cf52f" +content-hash = "ce851124c34915a6ba98c54fce1a0837807e2fec7d8fdbf218b5bcd03ef62ce0" diff 
--git a/pyproject.toml b/pyproject.toml index 5641a5510175b5..d1b1ca00c6a4ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ sounddevice = "*" spidev = { version = "*", platform = "linux" } sympy = "*" websocket_client = "*" +rev_ai = { git = "https://github.com/jakethesnake420/revai-python-sdk.git", branch = "patch-1" } # these should be removed markdown-it-py = "*" diff --git a/selfdrive/manager/process_config.py b/selfdrive/manager/process_config.py index 36d69b03bb447e..e7c49febf03eb9 100644 --- a/selfdrive/manager/process_config.py +++ b/selfdrive/manager/process_config.py @@ -48,7 +48,7 @@ def only_offroad(started, params, CP: car.CarParams) -> bool: NativeProcess("logcatd", "system/logcatd", ["./logcatd"], only_onroad), NativeProcess("proclogd", "system/proclogd", ["./proclogd"], only_onroad), PythonProcess("logmessaged", "system.logmessaged", always_run), - PythonProcess("micd", "system.micd", iscar), + PythonProcess("micd", "system.micd", only_onroad if TICI else always_run), PythonProcess("timezoned", "system.timezoned", always_run, enabled=not PC), PythonProcess("dmonitoringmodeld", "selfdrive.modeld.dmonitoringmodeld", driverview, enabled=(not PC or WEBCAM)), @@ -81,6 +81,8 @@ def only_offroad(started, params, CP: car.CarParams) -> bool: PythonProcess("updated", "selfdrive.updated", only_offroad, enabled=not PC), PythonProcess("uploader", "system.loggerd.uploader", always_run), PythonProcess("statsd", "selfdrive.statsd", always_run), + PythonProcess("speechd", "system.assistant.rev_speechd", only_onroad if TICI else always_run), + PythonProcess("wakewordd", "system.assistant.wakewordd", only_onroad if TICI else always_run), # debug procs NativeProcess("bridge", "cereal/messaging", ["./bridge"], notcar), diff --git a/selfdrive/ui/SConscript b/selfdrive/ui/SConscript index a3cba124fe79f0..0aeb66a7d37810 100644 --- a/selfdrive/ui/SConscript +++ b/selfdrive/ui/SConscript @@ -27,7 +27,7 @@ widgets_src = ["ui.cc", "qt/widgets/input.cc", "qt/widgets/wifi.cc", "qt/widgets/ssh_keys.cc", "qt/widgets/toggle.cc", "qt/widgets/controls.cc", "qt/widgets/offroad_alerts.cc", "qt/widgets/prime.cc", "qt/widgets/keyboard.cc", "qt/widgets/scrollview.cc", "qt/widgets/cameraview.cc", "#third_party/qrcode/QrCode.cc", - "qt/request_repeater.cc", "qt/qt_window.cc", "qt/network/networking.cc", "qt/network/wifi_manager.cc"] + "qt/request_repeater.cc", "qt/qt_window.cc", "qt/network/networking.cc", "qt/network/wifi_manager.cc", "qt/widgets/assistant.cc"] qt_env['CPPDEFINES'] = [] if maps: diff --git a/selfdrive/ui/qt/home.cc b/selfdrive/ui/qt/home.cc index 9dbe7cbae39b86..9e1df93ea1f7d7 100644 --- a/selfdrive/ui/qt/home.cc +++ b/selfdrive/ui/qt/home.cc @@ -65,6 +65,7 @@ void HomeWindow::updateState(const UIState &s) { body->setEnabled(true); slayout->setCurrentWidget(body); } + emit requestRaiseAssistantOverlay(); } void HomeWindow::offroadTransition(bool offroad) { diff --git a/selfdrive/ui/qt/home.h b/selfdrive/ui/qt/home.h index c6032852a13170..8fa2b369dd83ae 100644 --- a/selfdrive/ui/qt/home.h +++ b/selfdrive/ui/qt/home.h @@ -50,6 +50,7 @@ class HomeWindow : public QWidget { signals: void openSettings(int index = 0, const QString &param = ""); void closeSettings(); + void requestRaiseAssistantOverlay(); public slots: void offroadTransition(bool offroad); diff --git a/selfdrive/ui/qt/widgets/assistant.cc b/selfdrive/ui/qt/widgets/assistant.cc new file mode 100644 index 00000000000000..b497ace88a225a --- /dev/null +++ b/selfdrive/ui/qt/widgets/assistant.cc @@ -0,0 +1,88 @@ +#include
"selfdrive/ui/qt/widgets/assistant.h" + +AssistantOverlay::AssistantOverlay(QWidget *parent) : QLabel(parent) { + + setStyleSheet("QLabel {" + " background-color: #373737;" + " border-radius: 20px;" + " font-family: 'Inter';" + " font-size: 60px;" + " color: white;" // Text color + "}"); + + // Set up the animations + showAnimation = new QPropertyAnimation(this, "geometry"); + showAnimation->setDuration(250); // Duration in milliseconds + hideAnimation = new QPropertyAnimation(this, "geometry"); + hideAnimation->setDuration(250); + int height = 100; // Fixed height + setGeometry(0, 0, 0, height); + hide(); + + hideTimer = new QTimer(this); + connect(hideTimer, &QTimer::timeout, this, [this]() { animateOverlay(false); }); + QObject::connect(uiState(), &UIState::uiUpdate, this, &AssistantOverlay::updateState); +} + +void AssistantOverlay::animateOverlay(bool show) { + int parentCenterX = parentWidget()->width() / 2; + finalWidth = parentWidget()->width() * 0.5; + int startX = parentCenterX - finalWidth / 2; + QRect centerRect(parentCenterX, 0, 0, height()); // Centered, zero width + QRect fullRect(startX, 0, finalWidth, height()); // Adjusted x, final width + + if (show) { + showAnimation->setStartValue(centerRect); + showAnimation->setEndValue(fullRect); + this->show(); + showAnimation->start(); + } else { + hideAnimation->setStartValue(fullRect); + hideAnimation->setEndValue(centerRect); + hideAnimation->start(); + hideTimer->stop(); + } +} + +void AssistantOverlay::updateText(QString text) { + this->setText(text); + this->setAlignment(QFontMetrics(this->font()).horizontalAdvance(text) > this->finalWidth ? Qt::AlignRight : Qt::AlignCenter); +} + +void AssistantOverlay::updateState(const UIState &s) { + const SubMaster &sm = *(s.sm); + if (!sm.updated("speechToText")) return; + + static cereal::SpeechToText::State current_state = cereal::SpeechToText::State::NONE; + cereal::SpeechToText::State request_state = sm["speechToText"].getSpeechToText().getState(); + // Check for valid state transition + if (current_state == cereal::SpeechToText::State::BEGIN || + (current_state == cereal::SpeechToText::State::NONE && + (request_state == cereal::SpeechToText::State::EMPTY || + request_state == cereal::SpeechToText::State::FINAL || + request_state == cereal::SpeechToText::State::NONE)) || + request_state == cereal::SpeechToText::State::BEGIN) { + + current_state = request_state; // Update state + switch (current_state) { // Handle UI updates + case cereal::SpeechToText::State::BEGIN: + if (!hideTimer->isActive()) animateOverlay(true); + updateText("Hello, I'm listening"); + hideTimer->start(30000); + break; + case cereal::SpeechToText::State::EMPTY: + updateText("Sorry, I didn't catch that"); + hideTimer->start(8000); + break; + case cereal::SpeechToText::State::ERROR: + updateText("Sorry, an error occorred"); + hideTimer->start(8000); + break; + case cereal::SpeechToText::State::FINAL: + case cereal::SpeechToText::State::NONE: + updateText(QString::fromStdString(sm["speechToText"].getSpeechToText().getTranscript())); + hideTimer->start(request_state == cereal::SpeechToText::State::FINAL ? 
8000 : 30000); + break; + } + } +} diff --git a/selfdrive/ui/qt/widgets/assistant.h b/selfdrive/ui/qt/widgets/assistant.h new file mode 100644 index 00000000000000..f10a84398a10de --- /dev/null +++ b/selfdrive/ui/qt/widgets/assistant.h @@ -0,0 +1,25 @@ +#pragma once + +#include "selfdrive/ui/ui.h" +#include <QLabel> +#include <QPropertyAnimation> + +class AssistantOverlay : public QLabel { + Q_OBJECT + +public: + explicit AssistantOverlay(QWidget *parent = nullptr); + void animateOverlay(bool show); + +private: + void updateText(const QString text); + void startHideTimer(); + QTimer *hideTimer; + QPropertyAnimation *showAnimation; + QPropertyAnimation *hideAnimation; + int finalWidth; + +private slots: + void updateState(const UIState &s); + +}; diff --git a/selfdrive/ui/qt/window.cc b/selfdrive/ui/qt/window.cc index 74fd05ed7bee77..86848337f51d11 100644 --- a/selfdrive/ui/qt/window.cc +++ b/selfdrive/ui/qt/window.cc @@ -8,10 +8,12 @@ MainWindow::MainWindow(QWidget *parent) : QWidget(parent) { main_layout = new QStackedLayout(this); main_layout->setMargin(0); + assistantOverlay = new AssistantOverlay(this); homeWindow = new HomeWindow(this); main_layout->addWidget(homeWindow); QObject::connect(homeWindow, &HomeWindow::openSettings, this, &MainWindow::openSettings); QObject::connect(homeWindow, &HomeWindow::closeSettings, this, &MainWindow::closeSettings); + QObject::connect(homeWindow, &HomeWindow::requestRaiseAssistantOverlay, this, &MainWindow::raiseAssistantOverlay); settingsWindow = new SettingsWindow(this); main_layout->addWidget(settingsWindow); @@ -37,11 +39,13 @@ MainWindow::MainWindow(QWidget *parent) : QWidget(parent) { if (!offroad) { closeSettings(); } + assistantOverlay->raise(); }); QObject::connect(device(), &Device::interactiveTimeout, [=]() { if (main_layout->currentWidget() == settingsWindow) { closeSettings(); } + assistantOverlay->raise(); }); // load fonts @@ -68,6 +72,7 @@ void MainWindow::openSettings(int index, const QString &param) { main_layout->setCurrentWidget(settingsWindow); settingsWindow->setCurrentPanel(index, param); + assistantOverlay->raise(); } void MainWindow::closeSettings() { @@ -80,6 +85,7 @@ void MainWindow::closeSettings() { homeWindow->showMapPanel(true); } } + assistantOverlay->raise(); } bool MainWindow::eventFilter(QObject *obj, QEvent *event) { diff --git a/selfdrive/ui/qt/window.h b/selfdrive/ui/qt/window.h index 05b61e1f762bff..2fcd2e19758e47 100644 --- a/selfdrive/ui/qt/window.h +++ b/selfdrive/ui/qt/window.h @@ -6,6 +6,7 @@ #include "selfdrive/ui/qt/home.h" #include "selfdrive/ui/qt/offroad/onboarding.h" #include "selfdrive/ui/qt/offroad/settings.h" +#include "selfdrive/ui/qt/widgets/assistant.h" class MainWindow : public QWidget { Q_OBJECT @@ -17,9 +18,11 @@ class MainWindow : public QWidget { bool eventFilter(QObject *obj, QEvent *event) override; void openSettings(int index = 0, const QString &param = ""); void closeSettings(); + void raiseAssistantOverlay() {assistantOverlay->raise();} QStackedLayout *main_layout; HomeWindow *homeWindow; SettingsWindow *settingsWindow; OnboardingWindow *onboardingWindow; + AssistantOverlay *assistantOverlay; }; diff --git a/selfdrive/ui/ui.cc b/selfdrive/ui/ui.cc index 9afd22f13a12da..fb9b05e893f791 100644 --- a/selfdrive/ui/ui.cc +++ b/selfdrive/ui/ui.cc @@ -248,7 +248,7 @@ UIState::UIState(QObject *parent) : QObject(parent) { sm = std::make_unique<SubMaster, const std::initializer_list<const char *>>({ "modelV2", "controlsState", "liveCalibration", "radarState", "deviceState", "pandaStates", "carParams", "driverMonitoringState",
"carState", "liveLocationKalman", "driverStateV2", - "wideRoadCameraState", "managerState", "navInstruction", "navRoute", "uiPlan", + "wideRoadCameraState", "managerState", "navInstruction", "navRoute", "uiPlan", "speechToText", }); Params params; diff --git a/system/assistant/nav_setter.py b/system/assistant/nav_setter.py new file mode 100644 index 00000000000000..5972f81e371cb1 --- /dev/null +++ b/system/assistant/nav_setter.py @@ -0,0 +1,59 @@ +from cereal import messaging, log +from openpilot.common.params import Params +import json + +STTState = log.SpeechToText.State + +sm = messaging.SubMaster(["speechToText"]) +import os +import re +import requests +import urllib.parse + +def get_coordinates_from_transcript(transcript, proximity, mapbox_access_token): + # Regular expression to find 'navigate to' or 'directions to' followed by an address + pattern = r'\b(navigate to|directions to)\b\s+(.*?)(\.|$)' + # Search for the pattern in the transcript + match = re.search(pattern, transcript, re.IGNORECASE) + if match: + address = match.group(2).strip() + encoded_address = urllib.parse.quote(address) + mapbox_url = f"https://api.mapbox.com/geocoding/v5/mapbox.places/{encoded_address}.json?access_token=pk.eyJ1IjoicnlsZXltY2MiLCJhIjoiY2xjeDl5aGp4MTBmeDNzb2Vua2QyNWN1bSJ9.CrbD-j1LQkBdOqyWcZneyQ" + response = requests.get(mapbox_url) + if response.status_code == 200: + data = response.json() + # Assuming the first result is the most relevant + if data['features']: + coordinates = { + "latitude": data['features'][0]['geometry']['coordinates'][0], + "longitude": data['features'][0]['geometry']['coordinates'][1], + } + return coordinates + print("No coordinates") + print(f"Mapbox API error: {response.status_code}") + return False + +def main(): + params = Params() + mapbox_access_token = os.environ["MAPBOX_TOKEN"] + while True: + dest = False + transcript: str = "" + sm.update(0) + if sm.updated["speechToText"]: + transcript = sm["speechToText"].transcript + if not sm["speechToText"].state == log.SpeechToText.State.final: + print(f'Interim result: {transcript}') + else: + print(f'Final result: {transcript}') + proximity = params.get("LastGPSPosition") + print(proximity) + dest = get_coordinates_from_transcript(transcript,proximity, mapbox_access_token) + if dest: + params.put("NavDestination", json.dumps(dest)) + print(dest) + dest = False + +if __name__ == "__main__": + main() + diff --git a/system/assistant/openwakeword/README.md b/system/assistant/openwakeword/README.md new file mode 100644 index 00000000000000..85ce95c9a64094 --- /dev/null +++ b/system/assistant/openwakeword/README.md @@ -0,0 +1,19 @@ +The openwakeword driectory code was copied from https://github.com/dscripka/openWakeWord +and then stripped down to the essentials for Openpilot's purposes. + +To test wake word detection on the comma device or PC, run wakeword.py and say "alexa". +Make sure you have onnxruntime==1.16.3 when running on the comma device. + +pip install onnxruntime==1.16.3 + +You can also run rev_speechd.py which will wait for the "WakeWordDetected" param to be set. +To setup the Rev.Ai api you need to install rev_ai: + +pip install rev_ai + +You also need to set your rev ai acccess token which can be obtained with a free account. https://www.rev.ai/access-token +Once you have your token you can paste it in launch_openpilot.sh. + +export REVAI_ACCESS_TOKEN="" + +Once you have everything set up you can run ./launch_openpilot and see the assistant overlay on the UI. 
You can also run ./ui, rev_speechd.py, wakeword.py, micd.py in their own terminals for testing. diff --git a/system/assistant/openwakeword/__init__.py b/system/assistant/openwakeword/__init__.py new file mode 100644 index 00000000000000..34016d2ecf710b --- /dev/null +++ b/system/assistant/openwakeword/__init__.py @@ -0,0 +1,50 @@ +import os +from openpilot.system.assistant.openwakeword.model import Model +assert Model + +FEATURE_MODELS = { + "embedding": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/embedding_model.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/embedding_model.tflite" + }, + "melspectrogram": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/melspectrogram.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/melspectrogram.tflite" + } +} + +VAD_MODELS = { + "silero_vad": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/silero_vad.onnx"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/silero_vad.onnx" + } +} + +MODELS = { + "alexa": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/alexa_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/alexa_v0.1.tflite" + }, + "hey_mycroft": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_mycroft_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_mycroft_v0.1.tflite" + }, + "hey_jarvis": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_jarvis_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_jarvis_v0.1.tflite" + }, + "hey_rhasspy": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/hey_rhasspy_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/hey_rhasspy_v0.1.tflite" + }, + "timer": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/timer_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/timer_v0.1.tflite" + }, + "weather": { + "model_path": os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources/models/weather_v0.1.tflite"), + "download_url": "https://github.com/dscripka/openWakeWord/releases/download/v0.5.1/weather_v0.1.tflite" + } +} + + diff --git a/system/assistant/openwakeword/model.py b/system/assistant/openwakeword/model.py new file mode 100644 index 00000000000000..344550618fc314 --- /dev/null +++ b/system/assistant/openwakeword/model.py @@ -0,0 +1,151 @@ +# Copyright 2022 David Scripka. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import onnxruntime as ort +import os +import functools +from collections import deque, defaultdict +from functools import partial +import time +from typing import List, DefaultDict, Dict +import numpy as np +from openpilot.system.assistant.openwakeword.utils import AudioFeatures + + +class Model(): + def __init__(self, wakeword_models: List[str], enable_speex_noise_suppression: bool = False, **kwargs ): + wakeword_model_names = [] + if len(wakeword_models) >= 1: + for _, i in enumerate(wakeword_models): + if os.path.exists(i): + wakeword_model_names.append(os.path.splitext(os.path.basename(i))[0]) + # Create attributes to store models and metadata + self.models = {} + self.model_inputs = {} + self.model_outputs = {} + self.model_prediction_function = {} + self.class_mapping = {} + + def onnx_predict(onnx_model, x): + return onnx_model.run(None, {onnx_model.get_inputs()[0].name: x}) + + for mdl_path, mdl_name in zip(wakeword_models, wakeword_model_names, strict=False): + # Load openwakeword models + sessionOptions = ort.SessionOptions() + sessionOptions.inter_op_num_threads = 1 + sessionOptions.intra_op_num_threads = 1 + self.models[mdl_name] = ort.InferenceSession(mdl_path, sess_options=sessionOptions, + providers=["CPUExecutionProvider"]) + self.model_inputs[mdl_name] = self.models[mdl_name].get_inputs()[0].shape[1] + self.model_outputs[mdl_name] = self.models[mdl_name].get_outputs()[0].shape[1] + pred_function = functools.partial(onnx_predict, self.models[mdl_name]) + self.model_prediction_function[mdl_name] = pred_function + + self.class_mapping[mdl_name] = {str(i): str(i) for i in range(self.model_outputs[mdl_name])} + # Create buffer to store frame predictions + self.prediction_buffer: DefaultDict[str, deque] = defaultdict(partial(deque, maxlen=30)) + # Initialize SpeexDSP noise canceller + if enable_speex_noise_suppression: + from speexdsp_ns import NoiseSuppression + self.speex_ns = NoiseSuppression.create(160, 16000) + else: + self.speex_ns = None + # Create AudioFeatures object + self.preprocessor = AudioFeatures(**kwargs) + + def get_parent_model_from_label(self, label): + parent_model = "" + for mdl in self.class_mapping.keys(): + if label in self.class_mapping[mdl].values(): + parent_model = mdl + elif label in self.class_mapping.keys() and label == mdl: + parent_model = mdl + return parent_model + + def reset(self): + self.prediction_buffer = defaultdict(partial(deque, maxlen=30)) + + def predict(self, x: np.ndarray, timing: bool = False): + # Check input data type + if not isinstance(x, np.ndarray): + raise ValueError(f"The input audio data (x) must by a Numpy array, instead received an object of type {type(x)}.") + # Setup timing dict + if timing: + timing_dict: Dict[str, Dict] = {} + timing_dict["models"] = {} + feature_start = time.time() + # Get audio features (optionally with Speex noise suppression) + if self.speex_ns: + n_prepared_samples = self.preprocessor(self._suppress_noise_with_speex(x)) + else: + n_prepared_samples = self.preprocessor(x) + if timing: + timing_dict["models"]["preprocessor"] = time.time() - feature_start + # Get predictions from model(s) + predictions = {} + for mdl in self.models.keys(): + if timing: + model_start = time.time() + # Run model to get predictions + if n_prepared_samples > 1280: + group_predictions = [] + for i in np.arange(n_prepared_samples//1280-1, -1, -1): + group_predictions.extend( + self.model_prediction_function[mdl]( + self.preprocessor.get_features( + self.model_inputs[mdl], + start_ndx=-self.model_inputs[mdl] - i + ) + ) 
+ ) + prediction = np.array(group_predictions).max(axis=0)[None, ] + elif n_prepared_samples == 1280: + prediction = self.model_prediction_function[mdl]( + self.preprocessor.get_features(self.model_inputs[mdl]) + ) + elif n_prepared_samples < 1280: # get previous prediction if there aren't enough samples + if self.model_outputs[mdl] == 1: + if len(self.prediction_buffer[mdl]) > 0: + prediction = [[[self.prediction_buffer[mdl][-1]]]] + else: + prediction = [[[0]]] + elif self.model_outputs[mdl] != 1: + n_classes = max([int(i) for i in self.class_mapping[mdl].keys()]) + prediction = [[[0]*(n_classes+1)]] + + if self.model_outputs[mdl] == 1: + predictions[mdl] = prediction[0][0][0] + else: + for int_label, cls in self.class_mapping[mdl].items(): + predictions[cls] = prediction[0][0][int(int_label)] + # Update prediction buffer, and zero predictions for first 5 frames during model initialization + for cls in predictions.keys(): + if len(self.prediction_buffer[cls]) < 5: + predictions[cls] = 0.0 + self.prediction_buffer[cls].append(predictions[cls]) + # Get timing information + if timing: + timing_dict["models"][mdl] = time.time() - model_start + if timing: + return predictions, timing_dict + else: + return predictions + + def _suppress_noise_with_speex(self, x: np.ndarray, frame_size: int = 160): + cleaned = [] + for i in range(0, x.shape[0], frame_size): + chunk = x[i:i+frame_size] + cleaned.append(self.speex_ns.process(chunk.tobytes())) + cleaned_bytestring = b''.join(cleaned) + cleaned_array = np.frombuffer(cleaned_bytestring, np.int16) + return cleaned_array diff --git a/system/assistant/openwakeword/utils.py b/system/assistant/openwakeword/utils.py new file mode 100644 index 00000000000000..a7555cb3c9e4e0 --- /dev/null +++ b/system/assistant/openwakeword/utils.py @@ -0,0 +1,207 @@ +# Copyright 2022 David Scripka. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Union, List, Callable, Deque +from collections import deque +import pathlib +import os +import numpy as np +import requests +from tqdm import tqdm +import openpilot.system.assistant.openwakeword as openwakeword +import onnxruntime as ort + + +# Base class for computing audio features using Google's speech_embedding +# model (https://tfhub.dev/google/speech_embedding/1) +class AudioFeatures(): + def __init__(self, + melspec_model_path: str = pathlib.Path(__file__).parent.parent / "models/melspectrogram.onnx", + embedding_model_path: str = pathlib.Path(__file__).parent.parent / "models/embedding_model.onnx", + sr: int = 16000, + ncpu: int = 1, + device: str = 'cpu' + ): + # Initialize ONNX options + sessionOptions = ort.SessionOptions() + sessionOptions.inter_op_num_threads = ncpu + sessionOptions.intra_op_num_threads = ncpu + # Melspectrogram model + self.melspec_model = ort.InferenceSession(melspec_model_path, sess_options=sessionOptions, + providers=["CUDAExecutionProvider"] if device == "gpu" else ["CPUExecutionProvider"]) + self.onnx_execution_provider = self.melspec_model.get_providers()[0] + self.melspec_model_predict = lambda x: self.melspec_model.run(None, {'input': x}) + # Audio embedding model + self.embedding_model = ort.InferenceSession(embedding_model_path, sess_options=sessionOptions, + providers=["CUDAExecutionProvider"] if device == "gpu" + else ["CPUExecutionProvider"]) + self.embedding_model_predict = lambda x: self.embedding_model.run(None, {'input_1': x})[0].squeeze() + # Create databuffers + self.raw_data_buffer: Deque = deque(maxlen=sr*10) + self.melspectrogram_buffer = np.ones((76, 32)) # n_frames x num_features + self.melspectrogram_max_len = 10*97 # 97 is the number of frames in 1 second of 16hz audio + self.accumulated_samples = 0 # the samples added to the buffer since the audio preprocessor was last called + self.raw_data_remainder = np.empty(0) + self.feature_buffer = self._get_embeddings(np.random.randint(-1000, 1000, 16000*4).astype(np.int16)) + self.feature_buffer_max_len = 120 # ~10 seconds of feature buffer history + + def _get_melspectrogram(self, x: Union[np.ndarray, List], melspec_transform: Callable = lambda x: x/10 + 2): + x = np.asarray(x, dtype=np.float32) # Convert to numpy array of type float32 directly + if x.ndim == 1: + x = x[np.newaxis, :] # Add new axis for single sample + outputs = self.melspec_model_predict(x) + spec = np.squeeze(outputs[0]) + + return melspec_transform(spec) + + def _get_embeddings(self, x: np.ndarray, window_size: int = 76, step_size: int = 8, **kwargs): + spec = self._get_melspectrogram(x, **kwargs) + + # Check if input is too short + if spec.shape[0] < window_size: + raise ValueError("Input is too short for the specified window size.") + + # Collect windows + windows = [spec[i:i + window_size] for i in range(0, spec.shape[0] - window_size + 1, 8) + if i + window_size <= spec.shape[0]] + + # Convert to batch format + batch = np.array(windows)[..., np.newaxis].astype(np.float32) + embedding = self.embedding_model_predict(batch) + return embedding + + def _streaming_melspectrogram(self, n_samples): + if len(self.raw_data_buffer) < 400: + raise ValueError("The number of input frames must be at least 400 samples @ 16khz (25 ms)!") + + self.melspectrogram_buffer = np.vstack( + (self.melspectrogram_buffer, self._get_melspectrogram(list(self.raw_data_buffer)[-n_samples-160*3:])) + ) + + if self.melspectrogram_buffer.shape[0] > self.melspectrogram_max_len: + self.melspectrogram_buffer = 
self.melspectrogram_buffer[-self.melspectrogram_max_len:, :] + + def _buffer_raw_data(self, x): + + self.raw_data_buffer.extend(x.tolist() if isinstance(x, np.ndarray) else x) + + def _streaming_features(self, x): + # Add raw audio data to buffer, temporarily storing extra frames if not an even number of 80 ms chunks + processed_samples = 0 + + if self.raw_data_remainder.shape[0] != 0: + x = np.concatenate((self.raw_data_remainder, x)) + self.raw_data_remainder = np.empty(0) + + if self.accumulated_samples + x.shape[0] >= 1280: + remainder = (self.accumulated_samples + x.shape[0]) % 1280 + if remainder != 0: + x_even_chunks = x[0:-remainder] + self._buffer_raw_data(x_even_chunks) + self.accumulated_samples += len(x_even_chunks) + self.raw_data_remainder = x[-remainder:] + elif remainder == 0: + self._buffer_raw_data(x) + self.accumulated_samples += x.shape[0] + self.raw_data_remainder = np.empty(0) + else: + self.accumulated_samples += x.shape[0] + self._buffer_raw_data(x) + + # Only calculate melspectrogram once minimum samples are accumulated + if self.accumulated_samples >= 1280 and self.accumulated_samples % 1280 == 0: + self._streaming_melspectrogram(self.accumulated_samples) + # Calculate new audio embeddings/features based on update melspectrograms + for i in np.arange(self.accumulated_samples//1280-1, -1, -1): + ndx = -8*i + ndx = ndx if ndx != 0 else len(self.melspectrogram_buffer) + x = self.melspectrogram_buffer[-76 + ndx:ndx].astype(np.float32)[None, :, :, None] + if x.shape[1] == 76: + self.feature_buffer = np.vstack((self.feature_buffer, + self.embedding_model_predict(x))) + # Reset raw data buffer counter + processed_samples = self.accumulated_samples + self.accumulated_samples = 0 + + if self.feature_buffer.shape[0] > self.feature_buffer_max_len: + self.feature_buffer = self.feature_buffer[-self.feature_buffer_max_len:, :] + + return processed_samples if processed_samples != 0 else self.accumulated_samples + + def get_features(self, n_feature_frames: int = 16, start_ndx: int = -1): + if start_ndx != -1: + end_ndx = start_ndx + int(n_feature_frames) \ + if start_ndx + n_feature_frames != 0 else len(self.feature_buffer) + return self.feature_buffer[start_ndx:end_ndx, :][None, ].astype(np.float32) + else: + return self.feature_buffer[int(-1*n_feature_frames):, :][None, ].astype(np.float32) + + def __call__(self, x): + return self._streaming_features(x) + +def download_file(url, target_directory, file_size=None): + local_filename = url.split('/')[-1] + with requests.get(url, stream=True) as r: + if file_size is not None: + progress_bar = tqdm(total=file_size, unit='iB', unit_scale=True, desc=f"{local_filename}") + else: + total_size = int(r.headers.get('content-length', 0)) + progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True, desc=f"{local_filename}") + with open(os.path.join(target_directory, local_filename), 'wb') as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + progress_bar.update(len(chunk)) + progress_bar.close() + + +# Function to download models from GitHub release assets +def download_models(model_names: List[str] = ["",], + target_directory: str = os.path.join(pathlib.Path(__file__).parent.resolve(), "resources", "models")): + if not isinstance(model_names, list): + raise ValueError("The model_names argument must be a list of strings") + # Always download melspectrogram and embedding models, if they don't already exist + if not os.path.exists(target_directory): + os.makedirs(target_directory) + for feature_model in 
openwakeword.FEATURE_MODELS.values(): + if not os.path.exists(os.path.join(target_directory, feature_model["download_url"].split("/")[-1])): + download_file(feature_model["download_url"], target_directory) + download_file(feature_model["download_url"].replace(".tflite", ".onnx"), target_directory) + # Always download VAD models, if they don't already exist + for vad_model in openwakeword.VAD_MODELS.values(): + if not os.path.exists(os.path.join(target_directory, vad_model["download_url"].split("/")[-1])): + download_file(vad_model["download_url"], target_directory) + # Get all model urls + official_model_urls = [i["download_url"] for i in openwakeword.MODELS.values()] + official_model_names = [i["download_url"].split("/")[-1] for i in openwakeword.MODELS.values()] + if model_names != []: + for model_name in model_names: + url = [i for i, j in zip(official_model_urls, official_model_names, strict=False) if model_name in j] + if url != []: + if not os.path.exists(os.path.join(target_directory, url[0].split("/")[-1])): + download_file(url[0], target_directory) + download_file(url[0].replace(".tflite", ".onnx"), target_directory) + else: + for official_model_url in official_model_urls: + if not os.path.exists(os.path.join(target_directory, official_model_url.split("/")[-1])): + download_file(official_model_url, target_directory) + download_file(official_model_url.replace(".tflite", ".onnx"), target_directory) + +def re_arg(kwarg_map): + def decorator(func): + def wrapped(*args, **kwargs): + new_kwargs = {} + for k, v in kwargs.items(): + new_kwargs[kwarg_map.get(k, k)] = v + return func(*args, **new_kwargs) + return wrapped + return decorator diff --git a/system/assistant/rev_speechd.py b/system/assistant/rev_speechd.py new file mode 100644 index 00000000000000..67ec6b836ab41c --- /dev/null +++ b/system/assistant/rev_speechd.py @@ -0,0 +1,149 @@ +import os +import re +import json +import time +from rev_ai.models import MediaConfig +from rev_ai.streamingclient import RevAiStreamingClient +from websocket import _exceptions +from threading import Thread, Event +from queue import Queue +from cereal import messaging, log +from openpilot.common.params import Params +from openpilot.system.micd import SAMPLE_BUFFER, SAMPLE_RATE + +class AssistantWidgetControl: + def __init__(self): + self.pm = messaging.PubMaster(['speechToText']) + self.pm.wait_for_readers_to_update('speechToText', timeout=1) + def make_msg(self): + self.pm.wait_for_readers_to_update('speechToText', timeout=1) + return messaging.new_message('speechToText', valid=True) + def begin(self): + msg = self.make_msg() + msg.speechToText.state = log.SpeechToText.State.begin # Show + self.pm.send('speechToText', msg) + def error(self): + msg = self.make_msg() + msg.speechToText.state = log.SpeechToText.State.error + self.pm.send('speechToText', msg) + def empty(self): + msg = self.make_msg() + msg.speechToText.state = log.SpeechToText.State.empty + self.pm.send('speechToText', msg) + def set_text(self, text, final=True): + msg = self.make_msg() + msg.speechToText.transcript = text + msg.speechToText.state = log.SpeechToText.State.none if not final else log.SpeechToText.State.final + self.pm.send('speechToText', msg) + +class SpeechToTextProcessor: + TIMEOUT_DURATION = 10 + RATE = SAMPLE_RATE + CHUNK = SAMPLE_BUFFER + BUFFERS_PER_SECOND = SAMPLE_RATE/SAMPLE_BUFFER + QUEUE_TIME = 10 # Save the first 10 seconds to the queue + CONNECTION_TIMEOUT = 30 + + def __init__(self, access_token, queue_size=BUFFERS_PER_SECOND*QUEUE_TIME): + 
self.reva_access_token = access_token + self.audio_queue = Queue(maxsize=int(queue_size)) + self.stop_thread = Event() + self.awc = AssistantWidgetControl() + self.sm = messaging.SubMaster(['microphoneRaw']) + media_config = MediaConfig('audio/x-raw', 'interleaved', 16000, 'S16LE', 1) + self.streamclient = RevAiStreamingClient(self.reva_access_token, media_config) + self.p = Params() + self.error = False + + def microphone_data_collector(self): + """Thread function for collecting microphone data.""" + while not self.stop_thread.is_set(): + self.sm.update(0) + if self.sm.updated['microphoneRaw']: + data = self.sm['microphoneRaw'].rawSample + if not self.audio_queue.full(): + self.audio_queue.put(data) + else: + print("Queue is full, stopping") + self.stop_thread.set() + self.awc.error() + + def microphone_stream(self): + """Generator that yields audio chunks from the queue.""" + loop_count = 0 + start_time = time.time() + while True: + if loop_count >= self.audio_queue.maxsize or time.time() - start_time > self.CONNECTION_TIMEOUT: + print(f'Timeout reached. {loop_count=}, {time.time()-start_time=}') + break + elif self.stop_thread.is_set(): + break + elif not self.audio_queue.empty(): + data = self.audio_queue.get(block=True) + loop_count += 1 + yield data + else: + time.sleep(.1) + + def listen_print_loop(self, response_gen, final_transcript): + try: + for response in response_gen: + data = json.loads(response) + if data['type'] == 'final': + # Extract and concatenate the final transcript then send it + final_transcript = ' '.join([element['value'] for element in data['elements'] if element['type'] == 'text']) + else: + # Handle partial transcripts (optional) + partial_transcript = ' '.join([element['value'] for element in data['elements'] if element['type'] == 'text']) + self.awc.set_text(re.sub(r'<[^>]*>', '', partial_transcript), final=False) + except Exception as e: + print(f"An error occurred: {e}") + self.error=True + return re.sub(r'<[^>]*>', '', final_transcript) # remove atmospherics. ex: + + def run(self): + self.audio_queue.queue.clear() + collector_thread = Thread(target=self.microphone_data_collector) + final_transcript = "" + self.error = False + while not self.p.get_bool("WakeWordDetected"): + time.sleep(.4) + collector_thread.start() + self.awc.begin() + try: + response_gen = self.streamclient.start(self.microphone_stream(), + remove_disfluencies=True, # remove umms + filter_profanity=True, # brand integridity or something + detailed_partials=False, # don't need time stamps + ) + final_transcript = self.listen_print_loop(response_gen, final_transcript) + except _exceptions.WebSocketAddressException as e: + print(f"WebSocketAddressException: Address unreachable. {e}") + self.error = True + except Exception as e: + print(f"An error occurred: {e}") + self.error = True + finally: + print("Waiting for collector_thread to join...") + self.stop_thread.set() # end the stream + collector_thread.join() + self.stop_thread.clear() + print("collector_thread joined") + self.awc.set_text(final_transcript, final=True) + +def main(): + try: + reva_access_token = os.environ["REVAI_ACCESS_TOKEN"] + except KeyError: + print("your rev ai acccess token which can be obtained with a free account. 
https://www.rev.ai/access-token") + print("Set your REVAI_ACCESS_TOKEN with the command:") + print('export REVAI_ACCESS_TOKEN="your token string"') + processor = SpeechToTextProcessor(access_token=reva_access_token) + while True: + processor.p.put_bool("WakeWordDetected", False) + processor.run() + +if __name__ == "__main__": + main() + + diff --git a/system/assistant/tests/record_unlimied.py b/system/assistant/tests/record_unlimied.py new file mode 100644 index 00000000000000..b69a3e208616e7 --- /dev/null +++ b/system/assistant/tests/record_unlimied.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +"""Create a recording with arbitrary duration. + +PySoundFile (https://github.com/bastibe/PySoundFile/) has to be installed! + +""" +import argparse +import tempfile +import queue +import sys + + +def int_or_str(text): + """Helper function for argument parsing.""" + try: + return int(text) + except ValueError: + return text + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + '-l', '--list-devices', action='store_true', + help='show list of audio devices and exit') +parser.add_argument( + '-d', '--device', type=int_or_str, + help='input device (numeric ID or substring)') +parser.add_argument( + '-r', '--samplerate', type=int, help='sampling rate') +parser.add_argument( + '-c', '--channels', type=int, default=1, help='number of input channels') +parser.add_argument( + 'filename', nargs='?', metavar='FILENAME', + help='audio file to store recording to') +parser.add_argument( + '-t', '--subtype', type=str, help='sound file subtype (e.g. "PCM_24")') +args = parser.parse_args() + +try: + import sounddevice as sd + import soundfile as sf + import numpy # Make sure NumPy is loaded before it is used in the callback + assert numpy # avoid "imported but unused" message (W0611) + + if args.list_devices: + print(sd.query_devices()) + parser.exit(0) + if args.samplerate is None: + device_info = sd.query_devices(args.device, 'input') + # soundfile expects an int, sounddevice provides a float: + args.samplerate = int(device_info['default_samplerate']) + if args.filename is None: + args.filename = tempfile.mktemp(prefix='delme_rec_unlimited_', + suffix='.wav', dir='') + q = queue.Queue() + + def callback(indata, frames, time, status): + """This is called (from a separate thread) for each audio block.""" + if status: + print(status, file=sys.stderr) + q.put(indata.copy()) + + # Make sure the file is opened before recording anything: + with sf.SoundFile(args.filename, mode='x', samplerate=args.samplerate, + channels=args.channels, subtype=args.subtype) as file: + with sd.InputStream(samplerate=args.samplerate, device=args.device, + channels=args.channels, callback=callback): + print('#' * 80) + print('press Ctrl+C to stop the recording') + print('#' * 80) + while True: + file.write(q.get()) + +except KeyboardInterrupt: + print('\nRecording finished: ' + repr(args.filename)) + parser.exit(0) +except Exception as e: + parser.exit(type(e).__name__ + ': ' + str(e)) diff --git a/system/assistant/tests/test_assistant_widget_controls.py b/system/assistant/tests/test_assistant_widget_controls.py new file mode 100644 index 00000000000000..b4b66a25b5fdcf --- /dev/null +++ b/system/assistant/tests/test_assistant_widget_controls.py @@ -0,0 +1,10 @@ +from openpilot.system.assistant.rev_speechd import AssistantWidgetControl + +if __name__ == "__main__": + awc = AssistantWidgetControl() + awc.begin() + awc.set_text("TEST", final=False) + + awc.set_text("ENDING TEST", final=True) + awc.empty() + 
awc.error() diff --git a/system/assistant/tests/test_wakeword.py b/system/assistant/tests/test_wakeword.py new file mode 100644 index 00000000000000..e47650cfdcdbe7 --- /dev/null +++ b/system/assistant/tests/test_wakeword.py @@ -0,0 +1,53 @@ +import unittest +import numpy as np +import wave +import openpilot.system.micd as micd +from openpilot.system.assistant.wakewordd import WakeWordListener as WWL, download_models +from openpilot.common.params import Params +from pathlib import Path + +SOUND_FILE_PATH = Path(__file__).parent / 'sounds' +EASY = f'{SOUND_FILE_PATH}/alexa_easy.wav' +MEDIUM = f'{SOUND_FILE_PATH}/alexa_medium.wav' +HARD = f'{SOUND_FILE_PATH}/alexa_hard.wav' +CONVERSATION = f'{SOUND_FILE_PATH}/random_conversation.wav' +sounds_and_detects = {EASY: True, MEDIUM: True, HARD: True, CONVERSATION: False} + +class WakeWordListener(unittest.TestCase): + + def setUp(self): + # Download models if necessary + download_models([WWL.PHRASE_MODEL_NAME], "./models") + self.wwl = WWL(model_path=WWL.PHRASE_MODEL_PATH,threshhold=0.5) + self.params = Params() + + def test_wake_word(self): + # Create a Mic instance + mic_instance = micd.Mic() + for file,should_detect in sounds_and_detects.items(): + self.params.put_bool("WakeWordDetected", False) # Reset + print(f'testing {file}, {should_detect=}') + with wave.open(file, 'rb') as wf: + # Ensure the file is mono and has the correct sample rate + assert wf.getnchannels() == 1 + assert wf.getframerate() == micd.SAMPLE_RATE + detected = False + while True: + # Read a chunk of data + frames = wf.readframes(micd.SAMPLE_BUFFER) + if len(frames) == 0: + break + # Convert frames to numpy array + indata = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768 + indata = indata.reshape(-1, 1) + + # Pass the chunk to the callback + mic_instance.callback(indata, len(indata), None, None) + mic_instance.update() + self.wwl.wake_word_runner() + if self.params.get_bool("WakeWordDetected"): + detected = True + self.assertEqual(detected,should_detect, f'{detected=} {should_detect=} for sound {file}') + +if __name__ == '__main__': + unittest.main() diff --git a/system/assistant/wakewordd.py b/system/assistant/wakewordd.py new file mode 100644 index 00000000000000..81665f68f3209c --- /dev/null +++ b/system/assistant/wakewordd.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +import numpy as np +from pathlib import Path +from openpilot.system.assistant.openwakeword import Model +from openpilot.system.assistant.openwakeword.utils import download_models +from openpilot.common.params import Params +from cereal import messaging +from openpilot.system.micd import SAMPLE_BUFFER, SAMPLE_RATE + + +class WakeWordListener: + RATE = 12.5 + PHRASE_MODEL_NAME = "alexa_v0.1" + MODEL_DIR = Path(__file__).parent / 'models' + PHRASE_MODEL_PATH = f'{MODEL_DIR}/{PHRASE_MODEL_NAME}.onnx' + MEL_MODEL_PATH = f'{MODEL_DIR}/melspectrogram.onnx' + EMB_MODEL_PATH = f'{MODEL_DIR}/embedding_model.onnx' + THRESHOLD = .5 + def __init__(self, model_path=PHRASE_MODEL_PATH, threshhold=THRESHOLD): + self.owwModel = Model(wakeword_models=[model_path], melspec_model_path=self.MEL_MODEL_PATH, embedding_model_path=self.EMB_MODEL_PATH, sr=SAMPLE_RATE) + self.sm = messaging.SubMaster(['microphoneRaw']) + self.params = Params() + + self.model_name = model_path.split("/")[-1].split(".onnx")[0] + self.frame_index = 0 + self.frame_index_last = 0 + self.detected_last = False + self.threshhold = threshhold + + def update(self): + self.frame_index = self.sm['microphoneRaw'].frameIndex + if not 
(self.frame_index_last == self.frame_index or + self.frame_index - self.frame_index_last == SAMPLE_BUFFER): + print(f'skipped {(self.frame_index - self.frame_index_last)//SAMPLE_BUFFER-1} sample(s)') # TODO: Stop it from skipping + if self.frame_index_last == self.frame_index: + print("got the same frame") + return + self.frame_index_last = self.frame_index + sample = np.frombuffer(self.sm['microphoneRaw'].rawSample, dtype=np.int16) + prediction_score = self.owwModel.predict(sample) + detected = prediction_score[self.model_name] >= self.threshhold + if detected: + print("wake word detected") + self.params.put_bool("WakeWordDetected", True) + self.detected_last = detected + + def wake_word_runner(self): + self.sm.update(0) + if self.sm.updated['microphoneRaw']: + self.update() + +def main(): + download_models([WakeWordListener.PHRASE_MODEL_NAME], WakeWordListener.MODEL_DIR) + wwl = WakeWordListener() + while True: + wwl.wake_word_runner() + +if __name__ == "__main__": + main() diff --git a/system/micd.py b/system/micd.py index 8b738ebe939bcf..01ad7f2f8480ef 100755 --- a/system/micd.py +++ b/system/micd.py @@ -2,15 +2,14 @@ import numpy as np from cereal import messaging -from openpilot.common.realtime import Ratekeeper from openpilot.common.retry import retry from openpilot.common.swaglog import cloudlog +import threading -RATE = 10 -FFT_SAMPLES = 4096 +FFT_SAMPLES = 1280 REFERENCE_SPL = 2e-5 # newtons/m^2 -SAMPLE_RATE = 44100 -SAMPLE_BUFFER = 4096 # (approx 100ms) +SAMPLE_RATE = 16000 +SAMPLE_BUFFER = 1280 # (80ms) def calculate_spl(measurements): @@ -41,14 +40,18 @@ def apply_a_weighting(measurements: np.ndarray) -> np.ndarray: class Mic: def __init__(self): - self.rk = Ratekeeper(RATE) - self.pm = messaging.PubMaster(['microphone']) + + self.pm = messaging.PubMaster(['microphone', 'microphoneRaw']) + self.indata_ready_event = threading.Event() self.measurements = np.empty(0) self.sound_pressure = 0 self.sound_pressure_weighted = 0 self.sound_pressure_level_weighted = 0 + self.frame_index = 0 + self.frame_index_last = 0 + self.raw_sample = np.empty(SAMPLE_BUFFER, dtype=np.float32) def update(self): msg = messaging.new_message('microphone', valid=True) @@ -56,9 +59,19 @@ def update(self): msg.microphone.soundPressureWeighted = float(self.sound_pressure_weighted) msg.microphone.soundPressureWeightedDb = float(self.sound_pressure_level_weighted) - self.pm.send('microphone', msg) - self.rk.keep_time() + msg = messaging.new_message('microphoneRaw', valid=True) + + self.indata_ready_event.wait(.9) + msg.microphoneRaw.rawSample = np.int16(self.raw_sample * 32767).tobytes() + msg.microphoneRaw.frameIndex = self.frame_index + if not (self.frame_index_last == self.frame_index or + self.frame_index - self.frame_index_last == SAMPLE_BUFFER): + cloudlog.info(f'skipped {(self.frame_index - self.frame_index_last)//SAMPLE_BUFFER-1} samples') + + self.frame_index_last = self.frame_index + self.pm.send('microphoneRaw', msg) + self.indata_ready_event.clear() def callback(self, indata, frames, time, status): """ @@ -79,6 +92,10 @@ def callback(self, indata, frames, time, status): self.measurements = self.measurements[FFT_SAMPLES:] + self.frame_index += frames + self.raw_sample = indata[:, 0].copy() + self.indata_ready_event.set() + @retry(attempts=7, delay=3) def get_stream(self, sd): # reload sounddevice to reinitialize portaudio