McpServerTool output schema (#35069)

Agus Zubiaga created

Add an `Output` associated type to `McpServerTool`, so that we can
include its schema in `tools/list`.

Release Notes:

- N/A

Change summary

crates/agent_servers/src/claude/mcp_server.rs | 30 ++++++++++++---
crates/context_server/src/listener.rs         | 40 +++++++++++++++++---
crates/context_server/src/types.rs            |  2 +
3 files changed, 59 insertions(+), 13 deletions(-)

Detailed changes

crates/agent_servers/src/claude/mcp_server.rs 🔗

@@ -124,13 +124,19 @@ enum PermissionToolBehavior {
 
 impl McpServerTool for PermissionTool {
     type Input = PermissionToolParams;
+    type Output = ();
+
     const NAME: &'static str = "Confirmation";
 
     fn description(&self) -> &'static str {
         "Request permission for tool calls"
     }
 
-    async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
+    async fn run(
+        &self,
+        input: Self::Input,
+        cx: &mut AsyncApp,
+    ) -> Result<ToolResponse<Self::Output>> {
         let mut thread_rx = self.thread_rx.clone();
         let Some(thread) = thread_rx.recv().await?.upgrade() else {
             anyhow::bail!("Thread closed");
@@ -178,7 +184,7 @@ impl McpServerTool for PermissionTool {
             content: vec![ToolResponseContent::Text {
                 text: serde_json::to_string(&response)?,
             }],
-            structured_content: None,
+            structured_content: (),
         })
     }
 }
@@ -190,6 +196,8 @@ pub struct ReadTool {
 
 impl McpServerTool for ReadTool {
     type Input = ReadToolParams;
+    type Output = ();
+
     const NAME: &'static str = "Read";
 
     fn description(&self) -> &'static str {
@@ -206,7 +214,11 @@ impl McpServerTool for ReadTool {
         }
     }
 
-    async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
+    async fn run(
+        &self,
+        input: Self::Input,
+        cx: &mut AsyncApp,
+    ) -> Result<ToolResponse<Self::Output>> {
         let mut thread_rx = self.thread_rx.clone();
         let Some(thread) = thread_rx.recv().await?.upgrade() else {
             anyhow::bail!("Thread closed");
@@ -220,7 +232,7 @@ impl McpServerTool for ReadTool {
 
         Ok(ToolResponse {
             content: vec![ToolResponseContent::Text { text: content }],
-            structured_content: None,
+            structured_content: (),
         })
     }
 }
@@ -232,6 +244,8 @@ pub struct EditTool {
 
 impl McpServerTool for EditTool {
     type Input = EditToolParams;
+    type Output = ();
+
     const NAME: &'static str = "Edit";
 
     fn description(&self) -> &'static str {
@@ -248,7 +262,11 @@ impl McpServerTool for EditTool {
         }
     }
 
-    async fn run(&self, input: Self::Input, cx: &mut AsyncApp) -> Result<ToolResponse> {
+    async fn run(
+        &self,
+        input: Self::Input,
+        cx: &mut AsyncApp,
+    ) -> Result<ToolResponse<Self::Output>> {
         let mut thread_rx = self.thread_rx.clone();
         let Some(thread) = thread_rx.recv().await?.upgrade() else {
             anyhow::bail!("Thread closed");
@@ -273,7 +291,7 @@ impl McpServerTool for EditTool {
 
         Ok(ToolResponse {
             content: vec![],
-            structured_content: None,
+            structured_content: (),
         })
     }
 }

crates/context_server/src/listener.rs 🔗

@@ -41,8 +41,12 @@ struct RegisteredTool {
     handler: ToolHandler,
 }
 
-type ToolHandler =
-    Box<dyn Fn(Option<serde_json::Value>, &mut AsyncApp) -> Task<Result<ToolResponse>>>;
+type ToolHandler = Box<
+    dyn Fn(
+        Option<serde_json::Value>,
+        &mut AsyncApp,
+    ) -> Task<Result<ToolResponse<serde_json::Value>>>,
+>;
 type RequestHandler = Box<dyn Fn(RequestId, Option<Box<RawValue>>, &App) -> Task<String>>;
 
 impl McpServer {
@@ -79,11 +83,19 @@ impl McpServer {
     }
 
     pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
+        let output_schema = schemars::schema_for!(T::Output);
+        let unit_schema = schemars::schema_for!(());
+
         let registered_tool = RegisteredTool {
             tool: Tool {
                 name: T::NAME.into(),
                 description: Some(tool.description().into()),
                 input_schema: schemars::schema_for!(T::Input).into(),
+                output_schema: if output_schema == unit_schema {
+                    None
+                } else {
+                    Some(output_schema.into())
+                },
                 annotations: Some(tool.annotations()),
             },
             handler: Box::new({
@@ -96,7 +108,15 @@ impl McpServer {
 
                     let tool = tool.clone();
                     match input {
-                        Ok(input) => cx.spawn(async move |cx| tool.run(input, cx).await),
+                        Ok(input) => cx.spawn(async move |cx| {
+                            let output = tool.run(input, cx).await?;
+
+                            Ok(ToolResponse {
+                                content: output.content,
+                                structured_content: serde_json::to_value(output.structured_content)
+                                    .unwrap_or_default(),
+                            })
+                        }),
                         Err(err) => Task::ready(Err(err.into())),
                     }
                 }
@@ -259,7 +279,11 @@ impl McpServer {
                                 content: result.content,
                                 is_error: Some(false),
                                 meta: None,
-                                structured_content: result.structured_content,
+                                structured_content: if result.structured_content.is_null() {
+                                    None
+                                } else {
+                                    Some(result.structured_content)
+                                },
                             },
                             Err(err) => CallToolResponse {
                                 content: vec![ToolResponseContent::Text {
@@ -367,6 +391,8 @@ impl McpServer {
 
 pub trait McpServerTool {
     type Input: DeserializeOwned + JsonSchema;
+    type Output: Serialize + JsonSchema;
+
     const NAME: &'static str;
 
     fn description(&self) -> &'static str;
@@ -385,12 +411,12 @@ pub trait McpServerTool {
         &self,
         input: Self::Input,
         cx: &mut AsyncApp,
-    ) -> impl Future<Output = Result<ToolResponse>>;
+    ) -> impl Future<Output = Result<ToolResponse<Self::Output>>>;
 }
 
-pub struct ToolResponse {
+pub struct ToolResponse<T> {
     pub content: Vec<ToolResponseContent>,
-    pub structured_content: Option<serde_json::Value>,
+    pub structured_content: T,
 }
 
 #[derive(Serialize, Deserialize)]

crates/context_server/src/types.rs 🔗

@@ -502,6 +502,8 @@ pub struct Tool {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub description: Option<String>,
     pub input_schema: serde_json::Value,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub output_schema: Option<serde_json::Value>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub annotations: Option<ToolAnnotations>,
 }