Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::types::{
ClientOptions, ConnectionState, GetAuthStatusResponse, GetForegroundSessionResponse,
GetStatusResponse, LogLevel, ModelInfo, PingResponse, ProviderConfig, ResumeSessionConfig,
SessionConfig, SessionLifecycleEvent, SessionMetadata, SetForegroundSessionResponse, StopError,
SDK_PROTOCOL_VERSION,
MIN_PROTOCOL_VERSION, SDK_PROTOCOL_VERSION,
};
use serde_json::{json, Value};
use std::collections::HashMap;
Expand Down Expand Up @@ -520,6 +520,7 @@ pub struct Client {
lifecycle_handlers: Arc<RwLock<HashMap<u64, LifecycleHandler>>>,
next_lifecycle_handler_id: AtomicU64,
models_cache: Arc<Mutex<Option<Vec<ModelInfo>>>>,
negotiated_protocol_version: Arc<Mutex<Option<u32>>>,
}

impl Client {
Expand Down Expand Up @@ -571,6 +572,7 @@ impl Client {
lifecycle_handlers: Arc::new(RwLock::new(HashMap::new())),
next_lifecycle_handler_id: AtomicU64::new(1),
models_cache: Arc::new(Mutex::new(None)),
negotiated_protocol_version: Arc::new(Mutex::new(None)),
})
}

Expand Down Expand Up @@ -1196,7 +1198,8 @@ impl Client {
Ok(())
}

/// Verify protocol version matches.
/// Verify that the server's protocol version is within the supported range
/// and store the negotiated version.
async fn verify_protocol_version(&self) -> Result<()> {
// NOTE: We call the underlying RPC directly instead of ping() because ping() calls
// ensure_connected(), but we haven't set state to Connected yet.
Expand All @@ -1206,23 +1209,39 @@ impl Client {
.invoke("ping", Some(serde_json::json!({ "message": null })))
.await?;

let protocol_version = result
let server_version = result
.get("protocolVersion")
.and_then(|v| v.as_u64())
.map(|v| v as u32);

if let Some(version) = protocol_version {
if version != SDK_PROTOCOL_VERSION {
match server_version {
None => {
return Err(CopilotError::ProtocolMismatch {
expected: SDK_PROTOCOL_VERSION,
min: MIN_PROTOCOL_VERSION,
max: SDK_PROTOCOL_VERSION,
actual: 0,
});
}
Some(version) if version < MIN_PROTOCOL_VERSION || version > SDK_PROTOCOL_VERSION => {
return Err(CopilotError::ProtocolMismatch {
min: MIN_PROTOCOL_VERSION,
max: SDK_PROTOCOL_VERSION,
actual: version,
});
}
Some(version) => {
*self.negotiated_protocol_version.lock().await = Some(version);
}
}

Ok(())
}

/// Get the negotiated protocol version (set after successful start).
pub async fn negotiated_protocol_version(&self) -> Option<u32> {
*self.negotiated_protocol_version.lock().await
}

/// Set up notification and request handlers.
async fn setup_handlers(&self) -> Result<()> {
let rpc = self.rpc.lock().await;
Expand Down Expand Up @@ -1269,7 +1288,10 @@ impl Client {
// Clone Arc references for request handler
let sessions_for_requests = Arc::clone(&self.sessions);

// Set up request handler for tool.call and permission.request
// Protocol v2 backward-compatibility adapters.
// v2 servers send tool.call / permission.request as RPC requests.
// v3 servers send them as broadcast session events (handled in Session::handle_broadcast_event).
// We always register v2 handlers; a v3 server will simply never send these requests.
rpc.set_request_handler(move |method, params| {
use crate::jsonrpc::JsonRpcError;

Expand Down Expand Up @@ -1677,3 +1699,4 @@ mod tests {
assert_eq!(normalize_tool_arguments(&params), json!({}));
}
}

4 changes: 2 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pub enum CopilotError {
},

/// Protocol version mismatch
#[error("Protocol version mismatch: expected {expected}, got {actual}")]
ProtocolMismatch { expected: u32, actual: u32 },
#[error("Protocol version mismatch: SDK supports versions {min}-{max}, but server reports version {actual}. Please update your SDK or server to ensure compatibility.")]
ProtocolMismatch { min: u32, max: u32, actual: u32 },

/// Protocol error (framing, invalid messages, etc.)
#[error("Protocol error: {0}")]
Expand Down
45 changes: 43 additions & 2 deletions src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,40 @@ pub struct SkillInvokedData {
// Session Event (Discriminated Union)
// =============================================================================

/// Event data variants - the payload of each event type.
/// Data for `external_tool.requested` event (protocol v3 broadcast model).
///
/// In protocol v3, tool calls are broadcast as session events instead of
/// RPC requests. The SDK handles these internally and responds via
/// `session.tools.handlePendingToolCall` RPC.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ExternalToolRequestedData {
/// Unique request ID for correlating the response.
pub request_id: Option<String>,
/// Name of the tool being requested.
pub tool_name: Option<String>,
/// Tool call ID for tracking.
pub tool_call_id: Option<String>,
/// Arguments to pass to the tool handler.
pub arguments: Option<serde_json::Value>,
}

/// Data for `permission.requested` event (protocol v3 broadcast model).
///
/// In protocol v3, permission requests are broadcast as session events.
/// The SDK handles these internally and responds via
/// `session.permissions.handlePendingPermissionRequest` RPC.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
#[serde(rename_all = "camelCase")]
pub struct PermissionRequestedData {
/// Unique request ID for correlating the response.
pub request_id: Option<String>,
/// The permission request details.
pub permission_request: Option<serde_json::Value>,
}

/// Event data variants - the payload of each event type.
#[derive(Debug, Clone, Serialize)]
pub enum SessionEventData {
SessionStart(SessionStartData),
SessionResume(SessionResumeData),
Expand Down Expand Up @@ -597,6 +628,10 @@ pub enum SessionEventData {
SessionSnapshotRewind(SessionSnapshotRewindData),
SessionUsageInfo(SessionUsageInfoData),
SkillInvoked(SkillInvokedData),
/// External tool requested (protocol v3 broadcast).
ExternalToolRequested(ExternalToolRequestedData),
/// Permission requested (protocol v3 broadcast).
PermissionRequested(PermissionRequestedData),
/// Unknown event - preserves raw JSON for forward compatibility.
Unknown(serde_json::Value),
}
Expand Down Expand Up @@ -847,6 +882,12 @@ fn parse_event_data(event_type: &str, data: serde_json::Value) -> SessionEventDa
"skill.invoked" => serde_json::from_value(data)
.map(SessionEventData::SkillInvoked)
.unwrap_or_else(|_| SessionEventData::Unknown(serde_json::Value::Null)),
"external_tool.requested" => serde_json::from_value(data)
.map(SessionEventData::ExternalToolRequested)
.unwrap_or_else(|_| SessionEventData::Unknown(serde_json::Value::Null)),
"permission.requested" => serde_json::from_value(data)
.map(SessionEventData::PermissionRequested)
.unwrap_or_else(|_| SessionEventData::Unknown(serde_json::Value::Null)),
// Unknown event type - preserve raw data
_ => SessionEventData::Unknown(data),
}
Expand Down
134 changes: 133 additions & 1 deletion src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! A session represents a conversation with the Copilot CLI.

use crate::error::{CopilotError, Result};
use crate::events::{SessionEvent, SessionEventData};
use crate::events::{SessionEvent, SessionEventData, ExternalToolRequestedData, PermissionRequestedData};
use crate::types::{
ErrorOccurredHookInput, MessageOptions, PermissionRequest, PermissionRequestResult,
PostToolUseHookInput, PreToolUseHookInput, SessionEndHookInput, SessionHooks,
Expand Down Expand Up @@ -224,8 +224,15 @@ impl Session {

/// Dispatch an event to all subscribers.
///
/// Broadcast request events (external_tool.requested, permission.requested) are handled
/// internally before being forwarded to user handlers (protocol v3 model).
///
/// This is called by the Client when events are received.
pub async fn dispatch_event(&self, event: SessionEvent) {
// Handle broadcast request events (protocol v3) before dispatching to user handlers.
// Fire-and-forget: the response is sent asynchronously via RPC.
self.handle_broadcast_event(&event).await;

// Send to broadcast channel
let _ = self.event_tx.send(event.clone());

Expand All @@ -236,6 +243,131 @@ impl Session {
}
}

/// Handle broadcast request events by executing local handlers and responding via RPC.
///
/// Implements the protocol v3 broadcast model where tool calls and permission requests
/// are broadcast as session events to all clients.
async fn handle_broadcast_event(&self, event: &SessionEvent) {
match &event.data {
SessionEventData::ExternalToolRequested(data) => {
let request_id = match &data.request_id {
Some(id) => id.clone(),
None => return,
};
let tool_name = match &data.tool_name {
Some(name) => name.clone(),
None => return,
};

// Check if this session handles this tool
if self.get_tool(&tool_name).await.is_none() {
return; // This client doesn't handle this tool; another client will.
}

let tool_call_id = data.tool_call_id.clone().unwrap_or_default();
let arguments = data.arguments.clone().unwrap_or(serde_json::json!({}));
let session_id = self.session_id.clone();

// Execute tool and respond via handlePendingToolCall RPC
match self.invoke_tool(&tool_name, &arguments).await {
Ok(result) => {
// If the tool reported a failure with an error, send via top-level error
let params = if result.result_type == "failure" || result.result_type == "error" {
serde_json::json!({
"sessionId": session_id,
"requestId": request_id,
"error": result.error.unwrap_or_else(|| result.text_result_for_llm.clone()),
})
} else {
serde_json::json!({
"sessionId": session_id,
"requestId": request_id,
"result": {
"textResultForLlm": result.text_result_for_llm,
"resultType": result.result_type,
"toolTelemetry": result.tool_telemetry.unwrap_or_default(),
}
})
};
let _ = (self.invoke_fn)(
"session.tools.handlePendingToolCall",
Some(params),
).await;
}
Err(e) => {
let params = serde_json::json!({
"sessionId": session_id,
"requestId": request_id,
"error": e.to_string(),
});
let _ = (self.invoke_fn)(
"session.tools.handlePendingToolCall",
Some(params),
).await;
}
}
}
SessionEventData::PermissionRequested(data) => {
let request_id = match &data.request_id {
Some(id) => id.clone(),
None => return,
};
let perm_data = match &data.permission_request {
Some(d) => d.clone(),
None => return,
};

let session_id = self.session_id.clone();

// Build PermissionRequest from JSON
use crate::types::PermissionRequest;
let kind = perm_data
.get("kind")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
let tool_call_id = perm_data
.get("toolCallId")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let mut extension_data = std::collections::HashMap::new();
if let Some(obj) = perm_data.as_object() {
for (key, value) in obj {
if key != "kind" && key != "toolCallId" {
extension_data.insert(key.clone(), value.clone());
}
}
}

let request = PermissionRequest {
kind,
tool_call_id,
extension_data,
};

let result = self.handle_permission_request(&request).await;

let mut perm_result_inner = serde_json::json!({
"kind": result.kind,
});
if let Some(rules) = &result.rules {
perm_result_inner["rules"] = serde_json::json!(rules);
}
let perm_result = serde_json::json!({
"sessionId": session_id,
"requestId": request_id,
"result": perm_result_inner,
});

let _ = (self.invoke_fn)(
"session.permissions.handlePendingPermissionRequest",
Some(perm_result),
).await;
}
_ => {} // Not a broadcast request event
}
}

// =========================================================================
// Messaging
// =========================================================================
Expand Down
9 changes: 7 additions & 2 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ fn is_false(value: &bool) -> bool {
// Protocol Version
// =============================================================================

/// SDK protocol version - must match copilot-agent-runtime server.
pub const SDK_PROTOCOL_VERSION: u32 = 2;
/// Maximum protocol version this SDK supports.
/// This must match the version expected by the copilot-agent-runtime server.
pub const SDK_PROTOCOL_VERSION: u32 = 3;

/// Minimum protocol version this SDK can communicate with.
/// Servers reporting a version below this are rejected.
pub const MIN_PROTOCOL_VERSION: u32 = 2;

// =============================================================================
// Enums
Expand Down