Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/goose-acp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ uuid = { workspace = true, features = ["v7"] }
schemars = { workspace = true, features = ["derive"] }
goose-acp-macros = { path = "../goose-acp-macros" }
goose-sdk = { path = "../goose-sdk" }
subtle = "2.6"

[dev-dependencies]
async-trait = { workspace = true }
Expand All @@ -58,6 +59,7 @@ test-case = { workspace = true }
axum = { workspace = true }
rmcp = { workspace = true, features = ["transport-streamable-http-server"] }
sqlx = { version = "0.8", default-features = false, features = ["runtime-tokio-rustls", "sqlite"] }
tower = "0.5"

[package.metadata.cargo-machete]
# Used to provide extras imports for sacp
Expand Down
209 changes: 199 additions & 10 deletions crates/goose-acp/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,30 @@ use axum::{
body::Body,
extract::{
ws::{rejection::WebSocketUpgradeRejection, WebSocketUpgrade},
State,
Extension, State,
},
http::{header, Method, Request},
response::Response,
http::{header, HeaderValue, Method, Request, StatusCode},
middleware::{self, Next},
response::{IntoResponse, Response},
routing::{delete, get, post},
Router,
};
use serde_json::Value;
use subtle::ConstantTimeEq;
use tokio::sync::{mpsc, Mutex};
use tower_http::cors::{Any, CorsLayer};
use tower_http::cors::{AllowOrigin, CorsLayer};

use crate::server_factory::AcpServer;

pub(crate) const HEADER_SESSION_ID: &str = "Acp-Session-Id";
pub(crate) const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
pub(crate) const JSON_MIME_TYPE: &str = "application/json";
pub(crate) const WS_AUTH_SUBPROTOCOL_PREFIX: &str = "goose-acp-auth.";

#[derive(Clone, Debug)]
pub(crate) struct AcpAuthContext {
pub websocket_protocol: Option<String>,
}

pub(crate) struct TransportSession {
pub to_agent_tx: mpsc::Sender<String>,
Expand Down Expand Up @@ -82,11 +90,12 @@ pub(crate) fn is_initialize_request(value: &Value) -> bool {

async fn handle_get(
ws_upgrade: Result<WebSocketUpgrade, WebSocketUpgradeRejection>,
auth_context: Option<Extension<AcpAuthContext>>,
State(state): State<(Arc<http::HttpState>, Arc<websocket::WsState>)>,
request: Request<Body>,
) -> Response {
match ws_upgrade {
Ok(ws) => websocket::handle_get(state.1, ws).await,
Ok(ws) => websocket::handle_get(state.1, ws, auth_context.map(|context| context.0)).await,
Err(_) => http::handle_get(state.0, request).await,
}
}
Expand All @@ -95,26 +104,99 @@ async fn health() -> &'static str {
"ok"
}

pub fn create_router(server: Arc<AcpServer>) -> Router {
fn is_websocket_upgrade(request: &Request<Body>) -> bool {
request
.headers()
.get(header::UPGRADE)
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.eq_ignore_ascii_case("websocket"))
}

fn constant_time_token_matches(expected: &str, actual: &str) -> bool {
expected.as_bytes().ct_eq(actual.as_bytes()).into()
}

fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Option<&str> {
let value = headers.get(header::AUTHORIZATION)?.to_str().ok()?;
let token = value.strip_prefix("Bearer ")?;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Parse Authorization scheme case-insensitively

The bearer parser now requires the exact prefix "Bearer ", but HTTP authentication scheme names are case-insensitive, so valid headers like authorization: bearer <token> will be rejected. Because this middleware now guards all /acp HTTP requests, this can break compliant clients or proxies that normalize scheme casing; parse the scheme case-insensitively before extracting the token.

Useful? React with 👍 / 👎.

Some(token.trim())
}

fn extract_websocket_auth_protocol(
headers: &axum::http::HeaderMap,
expected_token: &str,
) -> Option<String> {
let protocols = headers.get(header::SEC_WEBSOCKET_PROTOCOL)?.to_str().ok()?;

protocols
.split(',')
.map(str::trim)
.find(|protocol| {
protocol
.strip_prefix(WS_AUTH_SUBPROTOCOL_PREFIX)
.is_some_and(|token| constant_time_token_matches(expected_token, token))
})
.map(ToOwned::to_owned)
}

async fn require_acp_auth(
State(auth_token): State<Arc<str>>,
mut request: Request<Body>,
next: Next,
) -> Response {
let authorized = if is_websocket_upgrade(&request) {
if let Some(protocol) =
extract_websocket_auth_protocol(request.headers(), auth_token.as_ref())
{
request.extensions_mut().insert(AcpAuthContext {
websocket_protocol: Some(protocol),
});
}

request
.extensions()
.get::<AcpAuthContext>()
.is_some_and(|context| context.websocket_protocol.is_some())
|| extract_bearer_token(request.headers())
.is_some_and(|token| constant_time_token_matches(auth_token.as_ref(), token))
} else {
extract_bearer_token(request.headers())
.is_some_and(|token| constant_time_token_matches(auth_token.as_ref(), token))
};

if !authorized {
return StatusCode::UNAUTHORIZED.into_response();
}

next.run(request).await
}

pub fn create_router(server: Arc<AcpServer>, auth_token: Arc<str>) -> Router {
let http_state = Arc::new(http::HttpState::new(server.clone()));
let ws_state = Arc::new(websocket::WsState::new(server));
let allowed_origins = [
"http://127.0.0.1".parse::<HeaderValue>().unwrap(),
"http://localhost".parse::<HeaderValue>().unwrap(),
"tauri://localhost".parse::<HeaderValue>().unwrap(),
Comment thread
mvanhorn marked this conversation as resolved.
Outdated
"https://tauri.localhost".parse::<HeaderValue>().unwrap(),
];

let cors = CorsLayer::new()
.allow_origin(Any)
.allow_origin(AllowOrigin::list(allowed_origins))
.allow_methods([Method::GET, Method::POST, Method::DELETE, Method::OPTIONS])
.allow_headers([
header::AUTHORIZATION,
header::CONTENT_TYPE,
header::ACCEPT,
HEADER_SESSION_ID.parse().unwrap(),
header::SEC_WEBSOCKET_PROTOCOL,
header::SEC_WEBSOCKET_VERSION,
header::SEC_WEBSOCKET_KEY,
header::CONNECTION,
header::UPGRADE,
]);

Router::new()
.route("/health", get(health))
.route("/status", get(health))
let acp_routes = Router::new()
.route(
"/acp",
post(http::handle_post).with_state(http_state.clone()),
Expand All @@ -124,5 +206,112 @@ pub fn create_router(server: Arc<AcpServer>) -> Router {
get(handle_get).with_state((http_state.clone(), ws_state)),
)
.route("/acp", delete(http::handle_delete).with_state(http_state))
.route_layer(middleware::from_fn_with_state(auth_token, require_acp_auth));
Comment thread
mvanhorn marked this conversation as resolved.

Router::new()
.route("/health", get(health))
.route("/status", get(health))
.merge(acp_routes)
.layer(cors)
}

#[cfg(test)]
mod tests {
use axum::{
extract::Extension,
http::{header::SEC_WEBSOCKET_PROTOCOL, Request, StatusCode},
routing::get,
Router,
};
use tower::ServiceExt;

use super::*;

fn auth_test_router(token: Arc<str>) -> Router {
Router::new()
.route("/acp", get(|| async { StatusCode::OK }))
.route(
"/ws",
get(|Extension(context): Extension<AcpAuthContext>| async move {
let protocol = context
.websocket_protocol
.expect("missing websocket protocol");
(
StatusCode::SWITCHING_PROTOCOLS,
[(SEC_WEBSOCKET_PROTOCOL, protocol)],
)
}),
)
.route_layer(middleware::from_fn_with_state(token, require_acp_auth))
}

#[tokio::test]
async fn acp_auth_rejects_missing_auth_header() {
let response = auth_test_router(Arc::<str>::from("secret-token"))
.oneshot(Request::builder().uri("/acp").body(Body::empty()).unwrap())
.await
.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn acp_auth_rejects_wrong_auth_header() {
let response = auth_test_router(Arc::<str>::from("secret-token"))
.oneshot(
Request::builder()
.uri("/acp")
.header(header::AUTHORIZATION, "Bearer wrong-token")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

#[tokio::test]
async fn acp_auth_accepts_matching_auth_header() {
let response = auth_test_router(Arc::<str>::from("secret-token"))
.oneshot(
Request::builder()
.uri("/acp")
.header(header::AUTHORIZATION, "Bearer secret-token")
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::OK);
}

#[tokio::test]
async fn acp_auth_accepts_matching_websocket_subprotocol() {
let protocol = format!("{WS_AUTH_SUBPROTOCOL_PREFIX}secret-token");
let response = auth_test_router(Arc::<str>::from("secret-token"))
.oneshot(
Request::builder()
.uri("/ws")
.header(header::CONNECTION, "Upgrade")
.header(header::UPGRADE, "websocket")
.header(header::SEC_WEBSOCKET_VERSION, "13")
.header(header::SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")
.header(header::SEC_WEBSOCKET_PROTOCOL, protocol.clone())
.body(Body::empty())
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
assert_eq!(
response
.headers()
.get(header::SEC_WEBSOCKET_PROTOCOL)
.unwrap(),
&protocol
);
}
}
15 changes: 13 additions & 2 deletions crates/goose-acp/src/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use tokio::sync::{mpsc, Mutex, RwLock};
use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
use tracing::{debug, error, info, warn};

use super::{TransportSession, HEADER_SESSION_ID};
use super::{AcpAuthContext, TransportSession, HEADER_SESSION_ID};
use crate::adapters::{ReceiverToAsyncRead, SenderToAsyncWrite};
use crate::server_factory::AcpServer;

Expand Down Expand Up @@ -66,7 +66,11 @@ impl WsState {
}
}

pub(crate) async fn handle_get(state: Arc<WsState>, ws: WebSocketUpgrade) -> Response {
pub(crate) async fn handle_get(
state: Arc<WsState>,
ws: WebSocketUpgrade,
auth_context: Option<AcpAuthContext>,
) -> Response {
let acp_session_id = match state.create_connection().await {
Ok(id) => id,
Err(e) => {
Expand All @@ -79,6 +83,13 @@ pub(crate) async fn handle_get(state: Arc<WsState>, ws: WebSocketUpgrade) -> Res
}
};

let protocol = auth_context.and_then(|context| context.websocket_protocol);
let ws = if let Some(protocol) = protocol {
ws.protocols([protocol])
} else {
ws
};

let mut response = ws.on_upgrade({
let acp_session_id = acp_session_id.clone();
move |socket| handle_ws(socket, state, acp_session_id)
Expand Down
12 changes: 11 additions & 1 deletion crates/goose-cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ async fn handle_mcp_command(server: McpCommand) -> Result<()> {
async fn handle_serve_command(host: String, port: u16, builtins: Vec<String>) -> Result<()> {
use goose::config::paths::Paths;
use goose_acp::server_factory::{AcpServer, AcpServerFactoryConfig};
use rand::{rngs::OsRng, RngCore};
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::info;
Expand All @@ -1081,10 +1082,19 @@ async fn handle_serve_command(host: String, port: u16, builtins: Vec<String>) ->
data_dir: Paths::data_dir(),
config_dir: Paths::config_dir(),
}));
let router = goose_acp::transport::create_router(server);
let mut auth_bytes = [0_u8; 32];
OsRng.fill_bytes(&mut auth_bytes);
let mut auth_token = String::with_capacity(auth_bytes.len() * 2);
for byte in auth_bytes {
use std::fmt::Write as _;

write!(&mut auth_token, "{byte:02x}")?;
}
let router = goose_acp::transport::create_router(server, Arc::<str>::from(auth_token.clone()));

let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
info!("Starting ACP server on {}", addr);
println!("ACP_TOKEN={auth_token}");

let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, router).await?;
Expand Down
15 changes: 13 additions & 2 deletions ui/goose2/src-tauri/src/commands/acp.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
use crate::services::acp::GooseServeProcess;

#[derive(serde::Serialize)]
pub struct GooseServeConnection {
url: String,
token: String,
}

#[tauri::command]
pub async fn get_goose_serve_url(app_handle: tauri::AppHandle) -> Result<String, String> {
pub async fn get_goose_serve_connection(
app_handle: tauri::AppHandle,
) -> Result<GooseServeConnection, String> {
let process = GooseServeProcess::get(app_handle).await?;
Ok(process.ws_url())
Ok(GooseServeConnection {
url: process.ws_url(),
token: process.token().to_string(),
})
}
2 changes: 1 addition & 1 deletion ui/goose2/src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn run() {
commands::agents::save_persona_avatar,
commands::agents::save_persona_avatar_bytes,
commands::agents::get_avatars_dir,
commands::acp::get_goose_serve_url,
commands::acp::get_goose_serve_connection,
commands::skills::create_skill,
commands::skills::list_skills,
commands::skills::delete_skill,
Expand Down
Loading
Loading