Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 116 additions & 4 deletions crates/dictyped/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ where
let _ = tx
.send(Err(Status::internal(format!("failed to record: {e:?}"))))
.await;
let _ = state.lock().expect("state poisoned").reset();
let _ = state.lock().expect("state poisoned").clear();
return;
}
};
Expand All @@ -96,7 +96,7 @@ where
Ok(client) => client,
Err(e) => {
let _ = tx.send(Err(e)).await;
let _ = state.lock().expect("state poisoned").reset();
let _ = state.lock().expect("state poisoned").clear();
return;
}
};
Expand All @@ -121,7 +121,7 @@ where
}
}

let _ = state.lock().expect("state poisoned").reset();
let _ = state.lock().expect("state poisoned").clear();
});

let response_stream = ReceiverStream::new(rx);
Expand All @@ -130,7 +130,7 @@ where
}

async fn stop(&self, _request: Request<StopRequest>) -> Result<Response<StopResponse>, Status> {
let stopped = self.state.lock().expect("state poisoned").reset();
let stopped = self.state.lock().expect("state poisoned").stop();
let response = StopResponse { stopped };
Ok(Response::new(response))
}
Expand Down Expand Up @@ -159,8 +159,10 @@ mod tests {

mod mock_recorders {
use std::io;
use std::time::Duration;

use async_stream::stream;
use tokio::time::sleep;
use tokio_util::bytes::Bytes;
use tokio_util::sync::CancellationToken;

Expand Down Expand Up @@ -196,6 +198,42 @@ mod tests {
}
}

pub(super) struct PacedNoiseRecorder {
remaining: usize,
delay: Duration,
}

impl AudioCapture for PacedNoiseRecorder {
type CaptureOption = usize;

fn new(emit_count: Self::CaptureOption) -> io::Result<Self> {
Ok(Self {
remaining: emit_count,
delay: Duration::from_millis(10),
})
}

fn create(&self, _cancellation_token: CancellationToken) -> io::Result<AudioStream> {
let mut remaining = self.remaining;
let delay = self.delay;
Ok(AudioStream(Box::pin(stream! {
let mut value = 0x1234_5678_u32;

loop {
if remaining == 0 {
return;
}
value = value.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
yield Ok(Bytes::from(value.to_le_bytes().to_vec()));
remaining -= 1;
if remaining > 0 {
sleep(delay).await;
}
}
})))
}
}

pub(super) struct ImmediateBadCaptureRecorder;

impl AudioCapture for ImmediateBadCaptureRecorder {
Expand Down Expand Up @@ -386,6 +424,21 @@ mod tests {
NoiseRecorder::new(capture_count).expect("NoiseRecorder must initialize"),
)
}

pub(super) fn paced_asr_service(
capture_count: usize,
) -> DictypeService<PacedNoiseRecorder> {
let mut clients = BTreeMap::new();
clients.insert(
"yes-asr".to_string(),
Arc::new(YesAsrClient {}) as Arc<dyn BackendClient + Send + Sync>,
);

DictypeService::new(
ClientStore::from_clients(clients),
PacedNoiseRecorder::new(capture_count).expect("PacedNoiseRecorder must initialize"),
)
}
}

#[tokio::test]
Expand Down Expand Up @@ -658,4 +711,63 @@ mod tests {
.await
.expect("transcribe should succeed again after stop");
}

#[tokio::test]
async fn stop_keeps_service_busy_until_stream_drains() {
let capture_count = 16;
let service = paced_asr_service(capture_count);

let mut first_stream = service
.transcribe(Request::new(TranscribeRequest {
profile_name: "yes-asr".to_string(),
}))
.await
.expect("transcribe should succeed")
.into_inner();

let first = first_stream
.next()
.await
.expect("stream should not be empty")
.expect("stream should not fail");
assert_eq!(first.text, "yes");

let stop_response = service
.stop(Request::new(StopRequest {}))
.await
.expect("stop should succeed")
.into_inner();
assert!(stop_response.stopped);

let second_err = service
.transcribe(Request::new(TranscribeRequest {
profile_name: "yes-asr".to_string(),
}))
.await
.expect_err("service must stay busy until the current stream drains");
assert_eq!(second_err.code(), Code::AlreadyExists);
assert_eq!(second_err.message(), "request exists");

let mut success_count = 1;
while let Some(response) = first_stream.next().await {
let response = response.expect("stream should not fail while draining");
assert_eq!(response.text, "yes");
success_count += 1;
}
assert_eq!(success_count, capture_count);
assert!(first_stream.next().await.is_none());

let restarted = service
.transcribe(Request::new(TranscribeRequest {
profile_name: "yes-asr".to_string(),
}))
.await
.expect("transcribe should succeed after drain completes")
.into_inner()
.next()
.await
.expect("restarted stream should yield a response")
.expect("restarted stream should not fail");
assert_eq!(restarted.text, "yes");
}
}
8 changes: 6 additions & 2 deletions crates/dictyped/src/service_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ impl ServiceState {
self.cancellation_token.is_some()
}

pub(crate) fn reset(&mut self) -> bool {
if let Some(cancellation_token) = self.cancellation_token.take() {
pub(crate) fn stop(&self) -> bool {
if let Some(cancellation_token) = &self.cancellation_token {
cancellation_token.cancel();
info!("stop: stopped session");
true
Expand All @@ -28,6 +28,10 @@ impl ServiceState {
}
}

pub(crate) fn clear(&mut self) -> bool {
self.cancellation_token.take().is_some()
}

pub(crate) fn replace(&mut self, cancellation_token: CancellationToken) -> Result<(), Status> {
let existing_cancellation = self.cancellation_token.replace(cancellation_token);

Expand Down