From 95d78ff8d573c88d51a8f77e039165c1de0ff16a Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Fri, 6 Jun 2025 17:47:21 +0200 Subject: [PATCH] context server: Make requests type safe (#32254) This changes the context server crate so that the input/output for a request are encoded at the type level, similar to how it is done for LSP requests. This also makes it easier to write tests that mock context servers, e.g. you can write something like this now when using the `test-support` feature of the `context-server` crate: ```rust create_fake_transport("mcp-1", cx.background_executor()) .on_request::(|_params| { PromptsListResponse { prompts: vec![/* some prompts */], .. } }) ``` Release Notes: - N/A --- crates/agent/src/context_server_tool.rs | 10 +- crates/agent/src/thread_store.rs | 8 +- .../src/context_store.rs | 9 +- .../src/context_server_command.rs | 42 ++-- crates/context_server/Cargo.toml | 3 + crates/context_server/src/context_server.rs | 2 + crates/context_server/src/protocol.rs | 139 +---------- crates/context_server/src/test.rs | 118 +++++++++ crates/context_server/src/types.rs | 195 ++++++++------- crates/project/Cargo.toml | 1 + crates/project/src/context_server_store.rs | 226 +++--------------- 11 files changed, 320 insertions(+), 433 deletions(-) create mode 100644 crates/context_server/src/test.rs diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index e4461f94de3ced9c13431de6e0eb02b7ffe646e4..2de43d157f8ed9303a1dd9c7f5b0b34543d4f44c 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -104,7 +104,15 @@ impl Tool for ContextServerTool { tool_name, arguments ); - let response = protocol.run_tool(tool_name, arguments).await?; + let response = protocol + .request::( + context_server::types::CallToolParams { + name: tool_name, + arguments, + meta: None, + }, + ) + .await?; let mut result = String::new(); for content in response.content { diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index cb3a0d3c63f88c294ad29f8b62c4970726268bb7..5d5cf21d93e24785abb5023f354668711ffa0387 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -566,10 +566,14 @@ impl ThreadStore { }; if protocol.capable(context_server::protocol::ServerCapability::Tools) { - if let Some(tools) = protocol.list_tools().await.log_err() { + if let Some(response) = protocol + .request::(()) + .await + .log_err() + { let tool_ids = tool_working_set .update(cx, |tool_working_set, _| { - tools + response .tools .into_iter() .map(|tool| { diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index 7af97b62a934f3e09c9713c17f31677feacfa609..7965ee592be8d386ec24839a25a44e2e6f47e3df 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -864,8 +864,13 @@ impl ContextStore { }; if protocol.capable(context_server::protocol::ServerCapability::Prompts) { - if let Some(prompts) = protocol.list_prompts().await.log_err() { - let slash_command_ids = prompts + if let Some(response) = protocol + .request::(()) + .await + .log_err() + { + let slash_command_ids = response + .prompts .into_iter() .filter(assistant_slash_commands::acceptable_prompt) .map(|prompt| { diff --git a/crates/assistant_slash_commands/src/context_server_command.rs b/crates/assistant_slash_commands/src/context_server_command.rs index 9b0ac1842687a765c4fc06f2e4d53836d2fb96c3..509076c1677919635c46e704d71f663c661693da 100644 --- a/crates/assistant_slash_commands/src/context_server_command.rs +++ b/crates/assistant_slash_commands/src/context_server_command.rs @@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand { cx.foreground_executor().spawn(async move { let protocol = server.client().context("Context server not initialized")?; - let completion_result = protocol - .completion( - context_server::types::CompletionReference::Prompt( - context_server::types::PromptReference { - r#type: context_server::types::PromptReferenceType::Prompt, - name: prompt_name, + let response = protocol + .request::( + context_server::types::CompletionCompleteParams { + reference: context_server::types::CompletionReference::Prompt( + context_server::types::PromptReference { + ty: context_server::types::PromptReferenceType::Prompt, + name: prompt_name, + }, + ), + argument: context_server::types::CompletionArgument { + name: arg_name, + value: arg_value, }, - ), - arg_name, - arg_value, + meta: None, + }, ) .await?; - let completions = completion_result + let completions = response + .completion .values .into_iter() .map(|value| ArgumentCompletion { @@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand { if let Some(server) = store.get_running_server(&server_id) { cx.foreground_executor().spawn(async move { let protocol = server.client().context("Context server not initialized")?; - let result = protocol.run_prompt(&prompt_name, prompt_args).await?; + let response = protocol + .request::( + context_server::types::PromptsGetParams { + name: prompt_name.clone(), + arguments: Some(prompt_args), + meta: None, + }, + ) + .await?; anyhow::ensure!( - result + response .messages .iter() .all(|msg| matches!(msg.role, context_server::types::Role::User)), @@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand { ); // Extract text from user messages into a single prompt string - let mut prompt = result + let mut prompt = response .messages .into_iter() .filter_map(|msg| match msg.content { @@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand { range: 0..(prompt.len()), icon: IconName::ZedAssistant, label: SharedString::from( - result + response .description .unwrap_or(format!("Result from {}", prompt_name)), ), diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 62a5354b39079e41214d75a6be41f261da3fae5f..96bb9e071f42dd1f6f7fa0782ed8ca425e1cd379 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -11,6 +11,9 @@ workspace = true [lib] path = "src/context_server.rs" +[features] +test-support = [] + [dependencies] anyhow.workspace = true async-trait.workspace = true diff --git a/crates/context_server/src/context_server.rs b/crates/context_server/src/context_server.rs index 19f2f75541b7e7fee62e87578acf2254f7a22a85..387235307a18839b26a6e76734c1b81b846bcca3 100644 --- a/crates/context_server/src/context_server.rs +++ b/crates/context_server/src/context_server.rs @@ -1,5 +1,7 @@ pub mod client; pub mod protocol; +#[cfg(any(test, feature = "test-support"))] +pub mod test; pub mod transport; pub mod types; diff --git a/crates/context_server/src/protocol.rs b/crates/context_server/src/protocol.rs index 782a1a4a6754a6363db9a233053d687983608af8..233df048d620f48a7488f1b008f25aa9059e88c0 100644 --- a/crates/context_server/src/protocol.rs +++ b/crates/context_server/src/protocol.rs @@ -6,10 +6,9 @@ //! of messages. use anyhow::Result; -use collections::HashMap; use crate::client::Client; -use crate::types; +use crate::types::{self, Request}; pub struct ModelContextProtocol { inner: Client, @@ -43,7 +42,7 @@ impl ModelContextProtocol { let response: types::InitializeResponse = self .inner - .request(types::RequestType::Initialize.as_str(), params) + .request(types::request::Initialize::METHOD, params) .await?; anyhow::ensure!( @@ -94,137 +93,7 @@ impl InitializedContextServerProtocol { } } - fn check_capability(&self, capability: ServerCapability) -> Result<()> { - anyhow::ensure!( - self.capable(capability), - "Server does not support {capability:?} capability" - ); - Ok(()) - } - - /// List the MCP prompts. - pub async fn list_prompts(&self) -> Result> { - self.check_capability(ServerCapability::Prompts)?; - - let response: types::PromptsListResponse = self - .inner - .request( - types::RequestType::PromptsList.as_str(), - serde_json::json!({}), - ) - .await?; - - Ok(response.prompts) - } - - /// List the MCP resources. - pub async fn list_resources(&self) -> Result { - self.check_capability(ServerCapability::Resources)?; - - let response: types::ResourcesListResponse = self - .inner - .request( - types::RequestType::ResourcesList.as_str(), - serde_json::json!({}), - ) - .await?; - - Ok(response) - } - - /// Executes a prompt with the given arguments and returns the result. - pub async fn run_prompt>( - &self, - prompt: P, - arguments: HashMap, - ) -> Result { - self.check_capability(ServerCapability::Prompts)?; - - let params = types::PromptsGetParams { - name: prompt.as_ref().to_string(), - arguments: Some(arguments), - meta: None, - }; - - let response: types::PromptsGetResponse = self - .inner - .request(types::RequestType::PromptsGet.as_str(), params) - .await?; - - Ok(response) - } - - pub async fn completion>( - &self, - reference: types::CompletionReference, - argument: P, - value: P, - ) -> Result { - let params = types::CompletionCompleteParams { - r#ref: reference, - argument: types::CompletionArgument { - name: argument.into(), - value: value.into(), - }, - meta: None, - }; - let result: types::CompletionCompleteResponse = self - .inner - .request(types::RequestType::CompletionComplete.as_str(), params) - .await?; - - let completion = types::Completion { - values: result.completion.values, - total: types::CompletionTotal::from_options( - result.completion.has_more, - result.completion.total, - ), - }; - - Ok(completion) - } - - /// List MCP tools. - pub async fn list_tools(&self) -> Result { - self.check_capability(ServerCapability::Tools)?; - - let response = self - .inner - .request::(types::RequestType::ListTools.as_str(), ()) - .await?; - - Ok(response) - } - - /// Executes a tool with the given arguments - pub async fn run_tool>( - &self, - tool: P, - arguments: Option>, - ) -> Result { - self.check_capability(ServerCapability::Tools)?; - - let params = types::CallToolParams { - name: tool.as_ref().to_string(), - arguments, - meta: None, - }; - - let response: types::CallToolResponse = self - .inner - .request(types::RequestType::CallTool.as_str(), params) - .await?; - - Ok(response) - } -} - -impl InitializedContextServerProtocol { - pub async fn request( - &self, - method: &str, - params: impl serde::Serialize, - ) -> Result { - self.inner.request(method, params).await + pub async fn request(&self, params: T::Params) -> Result { + self.inner.request(T::METHOD, params).await } } diff --git a/crates/context_server/src/test.rs b/crates/context_server/src/test.rs new file mode 100644 index 0000000000000000000000000000000000000000..d882a569841c231f387d36853d50b5404e7d0dd4 --- /dev/null +++ b/crates/context_server/src/test.rs @@ -0,0 +1,118 @@ +use anyhow::Context as _; +use collections::HashMap; +use futures::{Stream, StreamExt as _, lock::Mutex}; +use gpui::BackgroundExecutor; +use std::{pin::Pin, sync::Arc}; + +use crate::{ + transport::Transport, + types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities}, +}; + +pub fn create_fake_transport( + name: impl Into, + executor: BackgroundExecutor, +) -> FakeTransport { + let name = name.into(); + FakeTransport::new(executor).on_request::(move |_params| { + create_initialize_response(name.clone()) + }) +} + +fn create_initialize_response(server_name: String) -> InitializeResponse { + InitializeResponse { + protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()), + server_info: Implementation { + name: server_name, + version: "1.0.0".to_string(), + }, + capabilities: ServerCapabilities::default(), + meta: None, + } +} + +pub struct FakeTransport { + request_handlers: + HashMap<&'static str, Arc serde_json::Value + Send + Sync>>, + tx: futures::channel::mpsc::UnboundedSender, + rx: Arc>>, + executor: BackgroundExecutor, +} + +impl FakeTransport { + pub fn new(executor: BackgroundExecutor) -> Self { + let (tx, rx) = futures::channel::mpsc::unbounded(); + Self { + request_handlers: Default::default(), + tx, + rx: Arc::new(Mutex::new(rx)), + executor, + } + } + + pub fn on_request( + mut self, + handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static, + ) -> Self { + self.request_handlers.insert( + T::METHOD, + Arc::new(move |value| { + let params = value.get("params").expect("Missing parameters").clone(); + let params: T::Params = + serde_json::from_value(params).expect("Invalid parameters received"); + let response = handler(params); + serde_json::to_value(response).unwrap() + }), + ); + self + } +} + +#[async_trait::async_trait] +impl Transport for FakeTransport { + async fn send(&self, message: String) -> anyhow::Result<()> { + if let Ok(msg) = serde_json::from_str::(&message) { + let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); + + if let Some(method) = msg.get("method") { + let method = method.as_str().expect("Invalid method received"); + if let Some(handler) = self.request_handlers.get(method) { + let payload = handler(msg); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": payload + }); + self.tx + .unbounded_send(response.to_string()) + .context("sending a message")?; + } else { + log::debug!("No handler registered for MCP request '{method}'"); + } + } + } + Ok(()) + } + + fn receive(&self) -> Pin + Send>> { + let rx = self.rx.clone(); + let executor = self.executor.clone(); + Box::pin(futures::stream::unfold(rx, move |rx| { + let executor = executor.clone(); + async move { + let mut rx_guard = rx.lock().await; + executor.simulate_random_delay().await; + if let Some(message) = rx_guard.next().await { + drop(rx_guard); + Some((message, rx)) + } else { + None + } + } + })) + } + + fn receive_err(&self) -> Pin + Send>> { + Box::pin(futures::stream::empty()) + } +} diff --git a/crates/context_server/src/types.rs b/crates/context_server/src/types.rs index 83f08218f3b6fd8750366732c87b6a64200e6826..9c36c40228641e2740eda0a85c12e5b2dc5776eb 100644 --- a/crates/context_server/src/types.rs +++ b/crates/context_server/src/types.rs @@ -1,76 +1,92 @@ use collections::HashMap; +use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use url::Url; pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05"; -pub enum RequestType { - Initialize, - CallTool, - ResourcesUnsubscribe, - ResourcesSubscribe, - ResourcesRead, - ResourcesList, - LoggingSetLevel, - PromptsGet, - PromptsList, - CompletionComplete, - Ping, - ListTools, - ListResourceTemplates, - ListRoots, -} - -impl RequestType { - pub fn as_str(&self) -> &'static str { - match self { - RequestType::Initialize => "initialize", - RequestType::CallTool => "tools/call", - RequestType::ResourcesUnsubscribe => "resources/unsubscribe", - RequestType::ResourcesSubscribe => "resources/subscribe", - RequestType::ResourcesRead => "resources/read", - RequestType::ResourcesList => "resources/list", - RequestType::LoggingSetLevel => "logging/setLevel", - RequestType::PromptsGet => "prompts/get", - RequestType::PromptsList => "prompts/list", - RequestType::CompletionComplete => "completion/complete", - RequestType::Ping => "ping", - RequestType::ListTools => "tools/list", - RequestType::ListResourceTemplates => "resources/templates/list", - RequestType::ListRoots => "roots/list", - } - } -} +pub mod request { + use super::*; -impl TryFrom<&str> for RequestType { - type Error = (); - - fn try_from(s: &str) -> Result { - match s { - "initialize" => Ok(RequestType::Initialize), - "tools/call" => Ok(RequestType::CallTool), - "resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe), - "resources/subscribe" => Ok(RequestType::ResourcesSubscribe), - "resources/read" => Ok(RequestType::ResourcesRead), - "resources/list" => Ok(RequestType::ResourcesList), - "logging/setLevel" => Ok(RequestType::LoggingSetLevel), - "prompts/get" => Ok(RequestType::PromptsGet), - "prompts/list" => Ok(RequestType::PromptsList), - "completion/complete" => Ok(RequestType::CompletionComplete), - "ping" => Ok(RequestType::Ping), - "tools/list" => Ok(RequestType::ListTools), - "resources/templates/list" => Ok(RequestType::ListResourceTemplates), - "roots/list" => Ok(RequestType::ListRoots), - _ => Err(()), - } + macro_rules! request { + ($method:expr, $name:ident, $params:ty, $response:ty) => { + pub struct $name; + + impl Request for $name { + type Params = $params; + type Response = $response; + const METHOD: &'static str = $method; + } + }; } + + request!( + "initialize", + Initialize, + InitializeParams, + InitializeResponse + ); + request!("tools/call", CallTool, CallToolParams, CallToolResponse); + request!( + "resources/unsubscribe", + ResourcesUnsubscribe, + ResourcesUnsubscribeParams, + () + ); + request!( + "resources/subscribe", + ResourcesSubscribe, + ResourcesSubscribeParams, + () + ); + request!( + "resources/read", + ResourcesRead, + ResourcesReadParams, + ResourcesReadResponse + ); + request!("resources/list", ResourcesList, (), ResourcesListResponse); + request!( + "logging/setLevel", + LoggingSetLevel, + LoggingSetLevelParams, + () + ); + request!( + "prompts/get", + PromptsGet, + PromptsGetParams, + PromptsGetResponse + ); + request!("prompts/list", PromptsList, (), PromptsListResponse); + request!( + "completion/complete", + CompletionComplete, + CompletionCompleteParams, + CompletionCompleteResponse + ); + request!("ping", Ping, (), ()); + request!("tools/list", ListTools, (), ListToolsResponse); + request!( + "resources/templates/list", + ListResourceTemplates, + (), + ListResourceTemplatesResponse + ); + request!("roots/list", ListRoots, (), ListRootsResponse); +} + +pub trait Request { + type Params: DeserializeOwned + Serialize + Send + Sync + 'static; + type Response: DeserializeOwned + Serialize + Send + Sync + 'static; + const METHOD: &'static str; } #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(transparent)] pub struct ProtocolVersion(pub String); -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct InitializeParams { pub protocol_version: ProtocolVersion, @@ -80,7 +96,7 @@ pub struct InitializeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CallToolParams { pub name: String, @@ -90,7 +106,7 @@ pub struct CallToolParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesUnsubscribeParams { pub uri: Url, @@ -98,7 +114,7 @@ pub struct ResourcesUnsubscribeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesSubscribeParams { pub uri: Url, @@ -106,7 +122,7 @@ pub struct ResourcesSubscribeParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadParams { pub uri: Url, @@ -114,7 +130,7 @@ pub struct ResourcesReadParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct LoggingSetLevelParams { pub level: LoggingLevel, @@ -122,7 +138,7 @@ pub struct LoggingSetLevelParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsGetParams { pub name: String, @@ -132,37 +148,40 @@ pub struct PromptsGetParams { pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionCompleteParams { - pub r#ref: CompletionReference, + #[serde(rename = "ref")] + pub reference: CompletionReference, pub argument: CompletionArgument, #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] pub meta: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum CompletionReference { Prompt(PromptReference), Resource(ResourceReference), } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptReference { - pub r#type: PromptReferenceType, + #[serde(rename = "type")] + pub ty: PromptReferenceType, pub name: String, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourceReference { - pub r#type: PromptReferenceType, + #[serde(rename = "type")] + pub ty: PromptReferenceType, pub uri: Url, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum PromptReferenceType { #[serde(rename = "ref/prompt")] @@ -171,7 +190,7 @@ pub enum PromptReferenceType { Resource, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionArgument { pub name: String, @@ -188,7 +207,7 @@ pub struct InitializeResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesReadResponse { pub contents: Vec, @@ -196,14 +215,14 @@ pub struct ResourcesReadResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum ResourceContentsType { Text(TextResourceContents), Blob(BlobResourceContents), } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ResourcesListResponse { pub resources: Vec, @@ -220,7 +239,7 @@ pub struct SamplingMessage { pub content: MessageContent, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CreateMessageRequest { pub messages: Vec, @@ -296,7 +315,7 @@ pub struct MessageAnnotations { pub priority: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsGetResponse { #[serde(skip_serializing_if = "Option::is_none")] @@ -306,7 +325,7 @@ pub struct PromptsGetResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptsListResponse { pub prompts: Vec, @@ -316,7 +335,7 @@ pub struct PromptsListResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionCompleteResponse { pub completion: CompletionResult, @@ -324,7 +343,7 @@ pub struct CompletionCompleteResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CompletionResult { pub values: Vec, @@ -336,7 +355,7 @@ pub struct CompletionResult { pub meta: Option>, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Prompt { pub name: String, @@ -346,7 +365,7 @@ pub struct Prompt { pub arguments: Option>, } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptArgument { pub name: String, @@ -509,7 +528,7 @@ pub struct ModelHint { pub name: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum NotificationType { Initialized, @@ -589,7 +608,7 @@ pub struct Completion { pub total: CompletionTotal, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CallToolResponse { pub content: Vec, @@ -620,7 +639,7 @@ pub struct ListToolsResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListResourceTemplatesResponse { pub resource_templates: Vec, @@ -630,7 +649,7 @@ pub struct ListResourceTemplatesResponse { pub meta: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ListRootsResponse { pub roots: Vec, diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 7e506d218444781a2247003cbfe8b0efb7d44ddc..f208af54d77cd5bfc86afbccda8e96f113ee3778 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -91,6 +91,7 @@ workspace-hack.workspace = true [dev-dependencies] client = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } +context_server = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] } dap = { workspace = true, features = ["test-support"] } dap_adapters = { workspace = true, features = ["test-support"] } diff --git a/crates/project/src/context_server_store.rs b/crates/project/src/context_server_store.rs index aac9d5d4604e655303e7277f2d5612093e7416de..34d6abb96c9cb6567d5fa2bbcb9a769c5be5198b 100644 --- a/crates/project/src/context_server_store.rs +++ b/crates/project/src/context_server_store.rs @@ -499,17 +499,10 @@ impl ContextServerStore { mod tests { use super::*; use crate::{FakeFs, Project, project_settings::ProjectSettings}; - use context_server::{ - transport::Transport, - types::{ - self, Implementation, InitializeResponse, ProtocolVersion, RequestType, - ServerCapabilities, - }, - }; - use futures::{Stream, StreamExt as _, lock::Mutex}; - use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _}; + use context_server::test::create_fake_transport; + use gpui::{AppContext, TestAppContext, UpdateGlobal as _}; use serde_json::json; - use std::{cell::RefCell, pin::Pin, rc::Rc}; + use std::{cell::RefCell, rc::Rc}; use util::path; #[gpui::test] @@ -532,33 +525,17 @@ mod tests { ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) }); - let server_1_id = ContextServerId("mcp-1".into()); - let server_2_id = ContextServerId("mcp-2".into()); - - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-1".to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-2".to_string())) - } - _ => None, - }, - )); + let server_1_id = ContextServerId(SERVER_1_ID.into()); + let server_2_id = ContextServerId(SERVER_2_ID.into()); - let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); - let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + let server_1 = Arc::new(ContextServer::new( + server_1_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_2 = Arc::new(ContextServer::new( + server_2_id.clone(), + Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())), + )); store .update(cx, |store, cx| store.start_server(server_1, cx)) @@ -627,33 +604,17 @@ mod tests { ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx) }); - let server_1_id = ContextServerId("mcp-1".into()); - let server_2_id = ContextServerId("mcp-2".into()); - - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-1".to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response("mcp-2".to_string())) - } - _ => None, - }, - )); + let server_1_id = ContextServerId(SERVER_1_ID.into()); + let server_2_id = ContextServerId(SERVER_2_ID.into()); - let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone())); - let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone())); + let server_1 = Arc::new(ContextServer::new( + server_1_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_2 = Arc::new(ContextServer::new( + server_2_id.clone(), + Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())), + )); let _server_events = assert_server_events( &store, @@ -702,30 +663,14 @@ mod tests { let server_id = ContextServerId(SERVER_1_ID.into()); - let transport_1 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(SERVER_1_ID.to_string())) - } - _ => None, - }, - )); - - let transport_2 = - Arc::new(FakeTransport::new( - cx.executor(), - |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(SERVER_1_ID.to_string())) - } - _ => None, - }, - )); - - let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1)); - let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2)); + let server_with_same_id_1 = Arc::new(ContextServer::new( + server_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); + let server_with_same_id_2 = Arc::new(ContextServer::new( + server_id.clone(), + Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())), + )); // If we start another server with the same id, we should report that we stopped the previous one let _server_events = assert_server_events( @@ -794,16 +739,10 @@ mod tests { let store = cx.new(|cx| { ContextServerStore::test_maintain_server_loop( Box::new(move |id, _| { - let transport = FakeTransport::new(executor.clone(), { - let id = id.0.clone(); - move |_, request_type, _| match request_type { - Some(RequestType::Initialize) => { - Some(create_initialize_response(id.clone().to_string())) - } - _ => None, - } - }); - Arc::new(ContextServer::new(id.clone(), Arc::new(transport))) + Arc::new(ContextServer::new( + id.clone(), + Arc::new(create_fake_transport(id.0.to_string(), executor.clone())), + )) }), registry.clone(), project.read(cx).worktree_store(), @@ -1033,99 +972,4 @@ mod tests { (fs, project) } - - fn create_initialize_response(server_name: String) -> serde_json::Value { - serde_json::to_value(&InitializeResponse { - protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()), - server_info: Implementation { - name: server_name, - version: "1.0.0".to_string(), - }, - capabilities: ServerCapabilities::default(), - meta: None, - }) - .unwrap() - } - - struct FakeTransport { - on_request: Arc< - dyn Fn(u64, Option, serde_json::Value) -> Option - + Send - + Sync, - >, - tx: futures::channel::mpsc::UnboundedSender, - rx: Arc>>, - executor: BackgroundExecutor, - } - - impl FakeTransport { - fn new( - executor: BackgroundExecutor, - on_request: impl Fn( - u64, - Option, - serde_json::Value, - ) -> Option - + 'static - + Send - + Sync, - ) -> Self { - let (tx, rx) = futures::channel::mpsc::unbounded(); - Self { - on_request: Arc::new(on_request), - tx, - rx: Arc::new(Mutex::new(rx)), - executor, - } - } - } - - #[async_trait::async_trait] - impl Transport for FakeTransport { - async fn send(&self, message: String) -> Result<()> { - if let Ok(msg) = serde_json::from_str::(&message) { - let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0); - - if let Some(method) = msg.get("method") { - let request_type = method - .as_str() - .and_then(|method| types::RequestType::try_from(method).ok()); - if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) { - let response = serde_json::json!({ - "jsonrpc": "2.0", - "id": id, - "result": payload - }); - - self.tx - .unbounded_send(response.to_string()) - .context("sending a message")?; - } - } - } - Ok(()) - } - - fn receive(&self) -> Pin + Send>> { - let rx = self.rx.clone(); - let executor = self.executor.clone(); - Box::pin(futures::stream::unfold(rx, move |rx| { - let executor = executor.clone(); - async move { - let mut rx_guard = rx.lock().await; - executor.simulate_random_delay().await; - if let Some(message) = rx_guard.next().await { - drop(rx_guard); - Some((message, rx)) - } else { - None - } - } - })) - } - - fn receive_err(&self) -> Pin + Send>> { - Box::pin(futures::stream::empty()) - } - } }