Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions rust/src/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

use parking_lot::RwLock;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tracing::{Instrument, debug, error, warn};

use crate::{Error, ProtocolError};
Expand Down Expand Up @@ -184,6 +185,8 @@ pub struct JsonRpcClient {
pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
notification_tx: broadcast::Sender<JsonRpcNotification>,
request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
read_task: Mutex<Option<JoinHandle<()>>>,
write_task: Mutex<Option<JoinHandle<()>>>,
}

impl JsonRpcClient {
Expand All @@ -202,22 +205,24 @@ impl JsonRpcClient {
let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();

let writer_span = tracing::error_span!("jsonrpc_write_loop");
tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
let write_task = tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));

let client = Self {
request_id: AtomicU64::new(1),
write_tx,
pending_requests: Arc::new(RwLock::new(HashMap::new())),
notification_tx,
request_tx,
read_task: Mutex::new(None),
write_task: Mutex::new(Some(write_task)),
};

let pending_requests = client.pending_requests.clone();
let notification_tx_clone = client.notification_tx.clone();
let request_tx_clone = client.request_tx.clone();
let reader_span = tracing::error_span!("jsonrpc_read_loop");

tokio::spawn(
let read_task = tokio::spawn(
async move {
Self::read_loop(
reader,
Expand All @@ -229,10 +234,21 @@ impl JsonRpcClient {
}
.instrument(reader_span),
);
*client.read_task.lock() = Some(read_task);

client
}

/// Abort both background I/O tasks and fail every in-flight request.
///
/// After this returns no further frames are read or written; any caller
/// still awaiting a response observes its `oneshot` channel closed.
pub(crate) fn force_close(&self) {
    // Abort the reader first, then the writer, mirroring spawn order.
    for task_slot in [&self.read_task, &self.write_task] {
        if let Some(task) = task_slot.lock().take() {
            task.abort();
        }
    }
    // Dropping the senders wakes all pending awaiters with a closed-channel
    // error instead of letting them wait forever.
    self.pending_requests.write().clear();
}

/// Writer-actor task. Owns the `AsyncWrite`, drains the command queue,
/// and writes each frame atomically (header + body + flush) before
/// signaling the ack.
Expand Down
198 changes: 165 additions & 33 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ pub enum SessionError {
/// non-empty.
#[error("invalid SessionFsConfig: {0}")]
InvalidSessionFsConfig(String),

/// The CLI returned a different session ID than the one the SDK registered.
#[error("CLI returned session ID {returned} after SDK registered {requested}")]
SessionIdMismatch {
/// Session ID registered by the SDK before the RPC was sent.
requested: SessionId,
/// Session ID returned by the CLI.
returned: SessionId,
},
}

/// How the SDK communicates with the CLI server.
Expand Down Expand Up @@ -844,6 +853,23 @@ fn generate_connection_token() -> String {
hex
}

/// Generate a fresh RFC 4122 version-4 UUID string and wrap it in a
/// [`SessionId`], using the OS CSPRNG as the entropy source.
fn generate_session_id() -> SessionId {
    let mut bytes = [0u8; 16];
    getrandom::getrandom(&mut bytes)
        .expect("OS CSPRNG (getrandom) is unavailable; cannot generate session ID");
    // Stamp the fixed bits: version 4 in octet 6, variant 10xx in octet 8.
    bytes[6] = (bytes[6] & 0x0f) | 0x40;
    bytes[8] = (bytes[8] & 0x3f) | 0x80;
    // Hex-encode all 16 octets, then slice into the 8-4-4-4-12 grouping.
    let hex: String = bytes.iter().map(|byte| format!("{byte:02x}")).collect();
    let id = format!(
        "{}-{}-{}-{}-{}",
        &hex[0..8],
        &hex[8..12],
        &hex[12..16],
        &hex[16..20],
        &hex[20..32],
    );
    SessionId::from(id)
}

/// Connection to a GitHub Copilot CLI server (stdio, TCP, or external).
///
/// Cheaply cloneable — cloning shares the underlying connection.
Expand Down Expand Up @@ -873,6 +899,7 @@ struct ClientInner {
state: parking_lot::Mutex<ConnectionState>,
lifecycle_tx: broadcast::Sender<SessionLifecycleEvent>,
on_list_models: Option<Arc<dyn ListModelsHandler>>,
models_cache: parking_lot::Mutex<Arc<tokio::sync::OnceCell<Vec<Model>>>>,
session_fs_configured: bool,
on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
/// Token sent in the `connect` handshake. Auto-generated when the
Expand Down Expand Up @@ -1138,6 +1165,7 @@ impl Client {
state: parking_lot::Mutex::new(ConnectionState::Connected),
lifecycle_tx: broadcast::channel(256).0,
on_list_models,
models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
session_fs_configured,
on_get_trace_context,
effective_connection_token,
Expand Down Expand Up @@ -1752,10 +1780,17 @@ impl Client {
/// When [`ClientOptions::on_list_models`] is set, returns the handler's
/// result without making a `models.list` RPC. Otherwise queries the CLI.
pub async fn list_models(&self) -> Result<Vec<Model>, Error> {
if let Some(handler) = &self.inner.on_list_models {
return handler.list_models().await;
}
Ok(self.rpc().models().list().await?.models)
let cache = self.inner.models_cache.lock().clone();
let models = cache
.get_or_try_init(|| async {
if let Some(handler) = &self.inner.on_list_models {
handler.list_models().await
} else {
Ok(self.rpc().models().list().await?.models)
}
})
.await?;
Ok(models.clone())
}

/// Invoke [`ClientOptions::on_get_trace_context`] when configured,
Expand Down Expand Up @@ -1828,6 +1863,7 @@ impl Client {

let child = self.inner.child.lock().take();
*self.inner.state.lock() = ConnectionState::Disconnected;
*self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
if let Some(mut child) = child
&& let Err(e) = child.kill().await
{
Expand Down Expand Up @@ -1879,10 +1915,12 @@ impl Client {
{
error!(pid = ?pid, error = %e, "failed to send kill signal");
}
self.inner.rpc.force_close();
// Drop all session channels so any awaiters see a closed channel
// instead of waiting for responses that will never arrive.
self.inner.router.clear();
*self.inner.state.lock() = ConnectionState::Disconnected;
*self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
}

/// Subscribe to lifecycle events.
Expand Down Expand Up @@ -2405,43 +2443,137 @@ mod tests {
policy: None,
supported_reasoning_efforts: Vec::new(),
};
let handler = Arc::new(CountingHandler {
let handler: Arc<dyn ListModelsHandler> = Arc::new(CountingHandler {
calls: Arc::clone(&calls),
models: vec![model.clone()],
});

// We can't call list_models() through Client::start without a CLI, but we
// can exercise the override path by directly constructing a Client whose
// inner has the handler set. This is the same dispatch path as the real
// call; from_streams's None default is replaced via inner construction.
let inner = ClientInner {
child: parking_lot::Mutex::new(None),
rpc: {
let (req_tx, _req_rx) = mpsc::unbounded_channel();
let (notif_tx, _notif_rx) = broadcast::channel(16);
let (read_pipe, _write_pipe) = tokio::io::duplex(64);
let (_unused_read, write_pipe) = tokio::io::duplex(64);
JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
},
cwd: PathBuf::from("."),
request_rx: parking_lot::Mutex::new(None),
notification_tx: broadcast::channel(16).0,
router: router::SessionRouter::new(),
negotiated_protocol_version: OnceLock::new(),
state: parking_lot::Mutex::new(ConnectionState::Connected),
lifecycle_tx: broadcast::channel(16).0,
on_list_models: Some(handler),
session_fs_configured: false,
on_get_trace_context: None,
effective_connection_token: None,
};
let client = Client {
inner: Arc::new(inner),
};
let client = client_with_list_models_handler(handler);

let result = client.list_models().await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].id, "byok-gpt-4");
assert_eq!(calls.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn list_models_serializes_concurrent_cache_misses() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Handler that counts invocations and sleeps long enough that two
    // concurrent callers overlap on the same cache miss.
    struct SlowHandler {
        invocations: Arc<AtomicUsize>,
        catalog: Vec<Model>,
    }
    #[async_trait]
    impl ListModelsHandler for SlowHandler {
        async fn list_models(&self) -> Result<Vec<Model>, Error> {
            self.invocations.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(std::time::Duration::from_millis(25)).await;
            Ok(self.catalog.clone())
        }
    }

    let invocations = Arc::new(AtomicUsize::new(0));
    let catalog = vec![Model {
        billing: None,
        capabilities: ModelCapabilities {
            limits: None,
            supports: None,
        },
        default_reasoning_effort: None,
        id: "single-flight-model".into(),
        name: "Single Flight Model".into(),
        policy: None,
        supported_reasoning_efforts: Vec::new(),
    }];
    let handler: Arc<dyn ListModelsHandler> = Arc::new(SlowHandler {
        invocations: Arc::clone(&invocations),
        catalog,
    });
    let client = client_with_list_models_handler(handler);

    // Two overlapping cache misses must coalesce into a single handler call.
    let (first, second) = tokio::join!(client.list_models(), client.list_models());
    assert_eq!(first.unwrap()[0].id, "single-flight-model");
    assert_eq!(second.unwrap()[0].id, "single-flight-model");
    assert_eq!(invocations.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn cancelled_create_session_unregisters_pending_session() {
    // A pair of duplex pipes stands in for the CLI; nothing ever answers,
    // so create_session stays pending until we abort it.
    let (to_server, _server_rx) = tokio::io::duplex(8192);
    let (_server_tx, from_server) = tokio::io::duplex(8192);
    let client = Client::from_streams(from_server, to_server, std::env::temp_dir()).unwrap();

    let worker = client.clone();
    let pending =
        tokio::spawn(async move { worker.create_session(SessionConfig::default()).await });

    wait_for_pending_session_registration(&client).await;
    // Cancel the in-flight RPC and wait for the task to wind down.
    pending.abort();
    let _ = pending.await;

    // Cancellation must have unregistered the provisional session.
    assert!(client.inner.router.session_ids().is_empty());
    client.force_stop();
}

#[tokio::test]
async fn cancelled_resume_session_unregisters_pending_session() {
    // Inert pipes: the resume RPC never gets a reply, so it stays pending.
    let (to_server, _server_rx) = tokio::io::duplex(8192);
    let (_server_tx, from_server) = tokio::io::duplex(8192);
    let client = Client::from_streams(from_server, to_server, std::env::temp_dir()).unwrap();

    let session_id = SessionId::new("resume-cancel-test");
    let worker = client.clone();
    let pending = tokio::spawn(async move {
        worker
            .resume_session(ResumeSessionConfig::new(session_id))
            .await
    });

    wait_for_pending_session_registration(&client).await;
    // Abort the resume mid-flight and let the task finish unwinding.
    pending.abort();
    let _ = pending.await;

    // The provisional registration must have been rolled back.
    assert!(client.inner.router.session_ids().is_empty());
    client.force_stop();
}

/// Build a [`Client`] whose `on_list_models` override is `handler`, backed
/// by inert in-memory pipes — no CLI process and no real transport.
fn client_with_list_models_handler(handler: Arc<dyn ListModelsHandler>) -> Client {
    // Wire the JSON-RPC client to disconnected duplex halves so any real
    // RPC attempt simply goes nowhere.
    let (req_tx, _req_rx) = mpsc::unbounded_channel();
    let (notif_tx, _notif_rx) = broadcast::channel(16);
    let (read_pipe, _write_pipe) = tokio::io::duplex(64);
    let (_unused_read, write_pipe) = tokio::io::duplex(64);
    let rpc = JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx);

    let inner = ClientInner {
        child: parking_lot::Mutex::new(None),
        rpc,
        cwd: PathBuf::from("."),
        request_rx: parking_lot::Mutex::new(None),
        notification_tx: broadcast::channel(16).0,
        router: router::SessionRouter::new(),
        negotiated_protocol_version: OnceLock::new(),
        state: parking_lot::Mutex::new(ConnectionState::Connected),
        lifecycle_tx: broadcast::channel(16).0,
        on_list_models: Some(handler),
        models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
        session_fs_configured: false,
        on_get_trace_context: None,
        effective_connection_token: None,
    };
    Client {
        inner: Arc::new(inner),
    }
}

/// Poll until the router shows at least one registered session, panicking
/// if nothing appears within one second.
async fn wait_for_pending_session_registration(client: &Client) {
    let started = tokio::time::Instant::now();
    loop {
        if !client.inner.router.session_ids().is_empty() {
            break;
        }
        assert!(
            started.elapsed() < std::time::Duration::from_secs(1),
            "session was not registered"
        );
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
    }
}
}
Loading
Loading