Skip to content

Commit b5919e0

Browse files
committed
test(openapi): add OpenAPI v3 compatibility and test for nullable field schema workaround (#135)
1 parent 030b6f0 commit b5919e0

File tree

2 files changed

+239
-3
lines changed

2 files changed

+239
-3
lines changed

crates/rmcp/src/handler/server/tool.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ use crate::{
1414
};
1515
/// A shortcut for generating a JSON schema for a type.
1616
pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
17-
let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<T>();
17+
let settings = schemars::r#gen::SchemaSettings::openapi3();
18+
let generator = settings.into_generator();
19+
let schema = generator.into_root_schema_for::<T>();
1820
let object = serde_json::to_value(schema).expect("failed to serialize schema");
1921
match object {
2022
serde_json::Value::Object(object) => object,

crates/rmcp/tests/test_tool_macros.rs

Lines changed: 236 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
use std::sync::Arc;
1+
//cargo test --test test_tool_macros --features "client server"
22

3-
use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool};
3+
use rmcp::{
4+
ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt,
5+
model::{CallToolRequestParam, ClientInfo},
6+
};
7+
use rmcp::{handler::server::tool::ToolCallContext, tool};
48
use schemars::JsonSchema;
59
use serde::{Deserialize, Serialize};
10+
use serde_json;
11+
use std::sync::Arc;
612

713
#[derive(Serialize, Deserialize, JsonSchema)]
814
pub struct GetWeatherRequest {
@@ -36,6 +42,11 @@ impl Server {
3642
}
3743
#[tool(description = "Empty Parameter")]
3844
async fn empty_param(&self) {}
45+
46+
#[tool(description = "Optional Parameter")]
47+
async fn optional_param(&self, #[tool(param)] city: Option<String>) -> String {
48+
city.unwrap_or_default()
49+
}
3950
}
4051

4152
// define generic service trait
@@ -99,4 +110,227 @@ async fn test_tool_macros_with_generics() {
99110
assert_eq!(server.get_data().await, "mock data");
100111
}
101112

113+
#[tokio::test]
114+
async fn test_tool_macros_with_optional_param() {
115+
let _attr = Server::optional_param_tool_attr();
116+
// println!("{_attr:?}");
117+
let attr_type = _attr
118+
.input_schema
119+
.get("properties")
120+
.unwrap()
121+
.get("city")
122+
.unwrap()
123+
.get("type")
124+
.unwrap();
125+
println!("_attr.input_schema: {:?}", attr_type);
126+
assert_eq!(attr_type.as_str().unwrap(), "string");
127+
}
128+
102129
impl GetWeatherRequest {}
130+
131+
// Struct defined for testing optional field schema generation
132+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
133+
pub struct OptionalFieldTestSchema {
134+
#[schemars(description = "An optional description field")]
135+
pub description: Option<String>,
136+
}
137+
138+
// Struct defined for testing optional i64 field schema generation and null handling
139+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
140+
pub struct OptionalI64TestSchema {
141+
#[schemars(description = "An optional i64 field")]
142+
pub count: Option<i64>,
143+
pub mandatory_field: String, // Added to ensure non-empty object schema
144+
}
145+
146+
// Dummy struct to host the test tool method
147+
#[derive(Debug, Clone, Default)]
148+
pub struct OptionalSchemaTester {}
149+
150+
impl OptionalSchemaTester {
151+
// Dummy tool function using the test schema as an aggregated parameter
152+
#[tool(description = "A tool to test optional schema generation")]
153+
async fn test_optional_aggr(&self, #[tool(aggr)] _req: OptionalFieldTestSchema) {
154+
// Implementation doesn't matter for schema testing
155+
// Return type changed to () to satisfy IntoCallToolResult
156+
}
157+
158+
// Tool function to test optional i64 handling
159+
#[tool(description = "A tool to test optional i64 schema generation")]
160+
async fn test_optional_i64_aggr(&self, #[tool(aggr)] req: OptionalI64TestSchema) -> String {
161+
match req.count {
162+
Some(c) => format!("Received count: {}", c),
163+
None => "Received null count".to_string(),
164+
}
165+
}
166+
}
167+
168+
// Implement ServerHandler to route tool calls for OptionalSchemaTester
169+
impl ServerHandler for OptionalSchemaTester {
170+
async fn call_tool(
171+
&self,
172+
request: rmcp::model::CallToolRequestParam,
173+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
174+
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
175+
let tcc = ToolCallContext::new(self, request, context);
176+
match tcc.name() {
177+
"test_optional_aggr" => Self::test_optional_aggr_tool_call(tcc).await,
178+
"test_optional_i64_aggr" => Self::test_optional_i64_aggr_tool_call(tcc).await,
179+
_ => Err(rmcp::Error::invalid_params("method not found", None)),
180+
}
181+
}
182+
}
183+
184+
#[test]
185+
fn test_optional_field_schema_generation_via_macro() {
186+
// tests https://github.com/modelcontextprotocol/rust-sdk/issues/135
187+
188+
// Get the attributes generated by the #[tool] macro helper
189+
let tool_attr = OptionalSchemaTester::test_optional_aggr_tool_attr();
190+
191+
// Print the actual generated schema for debugging
192+
println!(
193+
"Actual input schema generated by macro: {:#?}",
194+
tool_attr.input_schema
195+
);
196+
197+
// Verify the schema generated for the aggregated OptionalFieldTestSchema
198+
// by the macro infrastructure (which should now use OpenAPI 3 settings)
199+
let input_schema_map = &*tool_attr.input_schema; // Dereference Arc<JsonObject>
200+
201+
// Check the schema for the 'description' property within the input schema
202+
let properties = input_schema_map
203+
.get("properties")
204+
.expect("Schema should have properties")
205+
.as_object()
206+
.unwrap();
207+
let description_schema = properties
208+
.get("description")
209+
.expect("Properties should include description")
210+
.as_object()
211+
.unwrap();
212+
213+
// Assert that the format is now `type: "string", nullable: true`
214+
assert_eq!(
215+
description_schema.get("type").map(|v| v.as_str().unwrap()),
216+
Some("string"),
217+
"Schema for Option<String> generated by macro should be type: \"string\""
218+
);
219+
assert_eq!(
220+
description_schema
221+
.get("nullable")
222+
.map(|v| v.as_bool().unwrap()),
223+
Some(true),
224+
"Schema for Option<String> generated by macro should have nullable: true"
225+
);
226+
// We still check the description is correct
227+
assert_eq!(
228+
description_schema
229+
.get("description")
230+
.map(|v| v.as_str().unwrap()),
231+
Some("An optional description field")
232+
);
233+
234+
// Ensure the old 'type: [T, null]' format is NOT used
235+
let type_value = description_schema.get("type").unwrap();
236+
assert!(
237+
!type_value.is_array(),
238+
"Schema type should not be an array [T, null]"
239+
);
240+
}
241+
242+
// Define a dummy client handler
243+
#[derive(Debug, Clone, Default)]
244+
struct DummyClientHandler {
245+
peer: Option<Peer<RoleClient>>,
246+
}
247+
248+
impl ClientHandler for DummyClientHandler {
249+
fn get_info(&self) -> ClientInfo {
250+
ClientInfo::default()
251+
}
252+
253+
fn set_peer(&mut self, peer: Peer<RoleClient>) {
254+
self.peer = Some(peer);
255+
}
256+
257+
fn get_peer(&self) -> Option<Peer<RoleClient>> {
258+
self.peer.clone()
259+
}
260+
}
261+
262+
#[tokio::test]
263+
async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> {
264+
let (server_transport, client_transport) = tokio::io::duplex(4096);
265+
266+
// Server setup
267+
let server = OptionalSchemaTester::default();
268+
let server_handle = tokio::spawn(async move {
269+
server.serve(server_transport).await?.waiting().await?;
270+
anyhow::Ok(())
271+
});
272+
273+
// Create a simple client handler that just forwards tool calls
274+
let client_handler = DummyClientHandler::default();
275+
let client = client_handler.serve(client_transport).await?;
276+
277+
// Test null case
278+
let result = client
279+
.call_tool(CallToolRequestParam {
280+
name: "test_optional_i64_aggr".into(),
281+
arguments: Some(
282+
serde_json::json!({
283+
"count": null,
284+
"mandatory_field": "test_null"
285+
})
286+
.as_object()
287+
.unwrap()
288+
.clone(),
289+
),
290+
})
291+
.await?;
292+
293+
let result_text = result
294+
.content
295+
.first()
296+
.and_then(|content| content.raw.as_text())
297+
.map(|text| text.text.as_str())
298+
.expect("Expected text content");
299+
300+
assert_eq!(
301+
result_text, "Received null count",
302+
"Null case should return expected message"
303+
);
304+
305+
// Test Some case
306+
let some_result = client
307+
.call_tool(CallToolRequestParam {
308+
name: "test_optional_i64_aggr".into(),
309+
arguments: Some(
310+
serde_json::json!({
311+
"count": 42,
312+
"mandatory_field": "test_some"
313+
})
314+
.as_object()
315+
.unwrap()
316+
.clone(),
317+
),
318+
})
319+
.await?;
320+
321+
let some_result_text = some_result
322+
.content
323+
.first()
324+
.and_then(|content| content.raw.as_text())
325+
.map(|text| text.text.as_str())
326+
.expect("Expected text content");
327+
328+
assert_eq!(
329+
some_result_text, "Received count: 42",
330+
"Some case should return expected message"
331+
);
332+
333+
client.cancel().await?;
334+
server_handle.await??;
335+
Ok(())
336+
}

0 commit comments

Comments
 (0)