Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from flask import Flask, render_template
from flask_sockets import Sockets
from wav2vec2_inference import Wave2Vec2Inference
import threading
from queue import Queue
import numpy as np
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
import base64

app = Flask(__name__)
sockets = Sockets(app)

class LiveWav2Vec2:
    """Runs Wav2Vec2 ASR over raw 16 kHz mono int16 PCM chunks in a background thread.

    Audio bytes go in via ``asr_input_queue``; ``[text, sample_length_s, confidence]``
    results come out via ``asr_output_queue`` (see :meth:`get_last_text`).
    """

    # Shared flag used to request shutdown of the background ASR thread.
    exit_event = threading.Event()

    # Sentinel pushed onto the input queue to unblock the worker's queue.get()
    # so it can terminate promptly.
    _CLOSE_SENTINEL = "close"

    def __init__(self, model_name):
        # HuggingFace model name; the model itself is loaded lazily inside
        # the worker thread so __init__ stays cheap.
        self.model_name = model_name

    def stop(self):
        """Stop the asr process."""
        LiveWav2Vec2.exit_event.set()
        # The worker blocks on queue.get(); push the sentinel so it wakes up
        # and observes the shutdown request instead of hanging forever.
        input_queue = getattr(self, "asr_input_queue", None)
        if input_queue is not None:
            input_queue.put(LiveWav2Vec2._CLOSE_SENTINEL)
        print("asr stopped")

    def start(self):
        """Start the asr process (worker thread plus its I/O queues)."""
        self.asr_output_queue = Queue()
        self.asr_input_queue = Queue()

        self.asr_process = threading.Thread(
            target=LiveWav2Vec2.asr_process,
            args=(self.model_name, self.asr_input_queue, self.asr_output_queue),
            daemon=True,  # don't keep the server process alive on shutdown
        )
        self.asr_process.start()

    @staticmethod
    def asr_process(model_name, in_queue, output_queue):
        """Worker loop: transcribe each audio chunk from in_queue onto output_queue."""
        wave2vec_asr = Wave2Vec2Inference(model_name, use_lm_if_possible=True)

        print("\nlistening to your voice\n")
        # Also honor exit_event so stop() actually terminates the loop.
        while not LiveWav2Vec2.exit_event.is_set():
            audio_frames = in_queue.get()
            if audio_frames == LiveWav2Vec2._CLOSE_SENTINEL:
                break
            # int16 PCM -> float in [-1, 1].
            float64_buffer = np.frombuffer(audio_frames, dtype=np.int16) / 32767
            text, confidence = wave2vec_asr.buffer_to_text(float64_buffer)
            text = text.lower()
            # Duration in seconds: 2 bytes per int16 sample at 16 kHz
            # (the original divided the *byte* count by 16000, doubling it).
            sample_length = len(audio_frames) / 2 / 16000

            if text != "":
                output_queue.put([text, sample_length, confidence])

    def get_last_text(self):
        """Block until the next result; returns [text, sample length (s), confidence]."""
        return self.asr_output_queue.get()

@app.route('/')
def index():
    """Serve the single-page streaming client UI."""
    page = render_template('index.html')
    return page

@sockets.route('/audio_stream')
def audio_stream(ws):
    """WebSocket endpoint: receive base64 audio chunks, stream back transcriptions.

    Each incoming message is a base64-encoded chunk of what the worker treats as
    16 kHz mono int16 PCM. For every non-empty chunk the handler blocks until the
    ASR worker yields a transcription, then sends the text back on the socket.
    """
    live_wav2vec2 = LiveWav2Vec2("jonatasgrosman/wav2vec2-large-xlsr-53-english")
    live_wav2vec2.start()

    try:
        while not ws.closed:
            message = ws.receive()
            # receive() yields None once the client disconnects; decoding None
            # would raise TypeError, so bail out first.
            if message is None:
                break
            try:
                audio_bytes = base64.b64decode(message)
            except ValueError:
                # Malformed base64 from the client (binascii.Error subclasses
                # ValueError): skip the chunk rather than kill the connection.
                continue
            if audio_bytes:
                live_wav2vec2.asr_input_queue.put(audio_bytes)
                text, sample_length, confidence = live_wav2vec2.get_last_text()
                print(f"{sample_length:.3f}s\t{confidence}\t{text}")
                ws.send(text)
    finally:
        # Always stop the per-connection ASR worker; the original only did so
        # on KeyboardInterrupt, leaking one worker thread per connection.
        live_wav2vec2.stop()

if __name__ == "__main__":
    # Serve the Flask app through gevent so flask-sockets WebSocket routes work.
    server_address = ('127.0.0.1', 5010)
    server = pywsgi.WSGIServer(server_address, app, handler_class=WebSocketHandler)
    print("Server running at http://127.0.0.1:5010/")
    server.serve_forever()
120 changes: 120 additions & 0 deletions requirements[Flask].txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
a==1.0  # NOTE(review): suspicious one-letter package — likely an accidental `pip freeze` artifact; verify it is actually needed
aiohttp==3.9.3
aiosignal==1.3.1
attrs==23.2.0
audioread==3.0.1
bidict==0.22.1
blinker==1.7.0
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
click==7.1.2
colorama==0.4.6
coloredlogs==15.0.1
datasets==2.16.1
decorator==5.1.1
dill==0.3.7
farasapy==0.0.14
filelock==3.13.1
Flask==1.1.4
Flask-SocketIO==5.3.6
Flask-Sockets==0.2.1
flatbuffers==23.5.26
frozenlist==1.4.1
fsspec==2023.10.0
gevent==23.9.1
gevent-websocket==0.10.1
greenlet==3.0.3
h11==0.14.0
halo==0.0.31
huggingface-hub==0.20.3
humanfriendly==10.0
hypothesis==6.97.4
idna==3.6
itsdangerous==1.1.0
Jinja2==2.11.3
joblib==1.3.2
kenlm @ https://github.com/kpu/kenlm/archive/master.zip#sha256=4d002dcde70b52d519cafff4dc0008696c40cff1c9184a531b40c7b45905be6b
lazy_loader==0.3
librosa==0.10.1
llvmlite==0.41.1
log-symbols==0.0.14
MarkupSafe==2.0.1
more-itertools==10.2.0
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.4
multiprocess==0.70.15
networkx==3.2.1
numba==0.58.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
onnx==1.15.0
onnxruntime==1.17.0
openai-whisper==20231117
packaging==23.2
pandas==2.2.0
platformdirs==4.2.0
pooch==1.8.0
protobuf==4.25.2
pyarrow==15.0.0
pyarrow-hotfix==0.6
PyAudio==0.2.14
pycparser==2.21
pyctcdecode==0.5.0
pydub==0.25.1
pygtrie==2.5.0
python-dateutil==2.8.2
python-engineio==4.8.2
python-socketio==5.11.0
pytz==2023.4
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
Rx==3.2.0
safetensors==0.4.2
samplerate==0.2.1
scikit-learn==1.4.0
scipy==1.12.0
simple-websocket==1.0.0
six==1.16.0
sortedcontainers==2.4.0
soundfile==0.12.1
soxr==0.3.7
SpeechRecognition==3.10.1
spinners==0.0.24
srt==3.5.3
sympy==1.12
termcolor==2.4.0
threadpoolctl==3.2.0
tiktoken==0.5.2
tnkeeh==0.0.9
tokenizers==0.15.1
torch==2.2.0
torchaudio==2.2.0
tqdm==4.66.1
transformers==4.37.2
triton==2.2.0
typing_extensions==4.9.0
tzdata==2023.4
urllib3==2.1.0
vosk==0.3.45
webrtcvad==2.0.10
websockets==12.0
Werkzeug==1.0.1
wsproto==1.2.0
xxhash==3.4.1
yarl==1.9.4
zope.event==5.0
zope.interface==6.1
84 changes: 84 additions & 0 deletions templates/index.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Realtime WebSocket Audio Streaming</title>
<style>
body {
background-color: black;
color: green;
}
</style>
</head>
<body>
<h1>Realtime WebSocket Audio Streaming</h1>
<button id="startButton">Start Streaming</button>
<button id="stopButton">Stop Streaming</button>
<div id="responseContainer"></div>
<script src="https://www.WebRTC-Experiment.com/RecordRTC.js"></script>
<script>
let ws = new WebSocket('ws://localhost:5010/audio_stream');
let mediaRecorder;

ws.onmessage = event => {
let responseContainer = document.getElementById('responseContainer');
console.log(event)
responseContainer.innerHTML += `<p>${event.data}</p>`;
};

let handleDataAvailable = (event) => {
if (event.size > 0) {
console.log('blob', event)
blobToBase64(event).then(b64 => {
ws.send(b64)
})
}
};

function blobToBase64(blob) {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.readAsDataURL(blob);
reader.onload = () => {
const base64String = reader.result.split(',')[1];
resolve(base64String);
};
reader.onerror = (error) => reject(error);
});
}

navigator.mediaDevices.getUserMedia({ audio: true })
.then(stream => {
let recorder = RecordRTC(stream, {
type: 'audio',
recorderType: StereoAudioRecorder,
mimeType: 'audio/wav',
timeSlice: 500,
desiredSampRate: 16000,
numberOfAudioChannels: 1,
ondataavailable: handleDataAvailable
});

document.getElementById('startButton').addEventListener('click', () => {
recorder.startRecording();
});

document.getElementById('stopButton').addEventListener('click', () => {
recorder.stopRecording();
recorder.getDataURL(dataURL => {
ws.send(dataURL);
});
});
});

ws.onopen = () => {
console.log('WebSocket connection opened');
};

ws.onclose = () => {
console.log('WebSocket connection closed');
};
</script>
</body>
</html>