Detailed changes
@@ -104,7 +104,15 @@ impl Tool for ContextServerTool {
tool_name,
arguments
);
- let response = protocol.run_tool(tool_name, arguments).await?;
+ let response = protocol
+ .request::<context_server::types::request::CallTool>(
+ context_server::types::CallToolParams {
+ name: tool_name,
+ arguments,
+ meta: None,
+ },
+ )
+ .await?;
let mut result = String::new();
for content in response.content {
@@ -566,10 +566,14 @@ impl ThreadStore {
};
if protocol.capable(context_server::protocol::ServerCapability::Tools) {
- if let Some(tools) = protocol.list_tools().await.log_err() {
+ if let Some(response) = protocol
+ .request::<context_server::types::request::ListTools>(())
+ .await
+ .log_err()
+ {
let tool_ids = tool_working_set
.update(cx, |tool_working_set, _| {
- tools
+ response
.tools
.into_iter()
.map(|tool| {
@@ -864,8 +864,13 @@ impl ContextStore {
};
if protocol.capable(context_server::protocol::ServerCapability::Prompts) {
- if let Some(prompts) = protocol.list_prompts().await.log_err() {
- let slash_command_ids = prompts
+ if let Some(response) = protocol
+ .request::<context_server::types::request::PromptsList>(())
+ .await
+ .log_err()
+ {
+ let slash_command_ids = response
+ .prompts
.into_iter()
.filter(assistant_slash_commands::acceptable_prompt)
.map(|prompt| {
@@ -86,20 +86,26 @@ impl SlashCommand for ContextServerSlashCommand {
cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?;
- let completion_result = protocol
- .completion(
- context_server::types::CompletionReference::Prompt(
- context_server::types::PromptReference {
- r#type: context_server::types::PromptReferenceType::Prompt,
- name: prompt_name,
+ let response = protocol
+ .request::<context_server::types::request::CompletionComplete>(
+ context_server::types::CompletionCompleteParams {
+ reference: context_server::types::CompletionReference::Prompt(
+ context_server::types::PromptReference {
+ ty: context_server::types::PromptReferenceType::Prompt,
+ name: prompt_name,
+ },
+ ),
+ argument: context_server::types::CompletionArgument {
+ name: arg_name,
+ value: arg_value,
},
- ),
- arg_name,
- arg_value,
+ meta: None,
+ },
)
.await?;
- let completions = completion_result
+ let completions = response
+ .completion
.values
.into_iter()
.map(|value| ArgumentCompletion {
@@ -138,10 +144,18 @@ impl SlashCommand for ContextServerSlashCommand {
if let Some(server) = store.get_running_server(&server_id) {
cx.foreground_executor().spawn(async move {
let protocol = server.client().context("Context server not initialized")?;
- let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
+ let response = protocol
+ .request::<context_server::types::request::PromptsGet>(
+ context_server::types::PromptsGetParams {
+ name: prompt_name.clone(),
+ arguments: Some(prompt_args),
+ meta: None,
+ },
+ )
+ .await?;
anyhow::ensure!(
- result
+ response
.messages
.iter()
.all(|msg| matches!(msg.role, context_server::types::Role::User)),
@@ -149,7 +163,7 @@ impl SlashCommand for ContextServerSlashCommand {
);
// Extract text from user messages into a single prompt string
- let mut prompt = result
+ let mut prompt = response
.messages
.into_iter()
.filter_map(|msg| match msg.content {
@@ -167,7 +181,7 @@ impl SlashCommand for ContextServerSlashCommand {
range: 0..(prompt.len()),
icon: IconName::ZedAssistant,
label: SharedString::from(
- result
+ response
.description
.unwrap_or(format!("Result from {}", prompt_name)),
),
@@ -11,6 +11,9 @@ workspace = true
[lib]
path = "src/context_server.rs"
+[features]
+test-support = []
+
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
@@ -1,5 +1,7 @@
pub mod client;
pub mod protocol;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
pub mod transport;
pub mod types;
@@ -6,10 +6,9 @@
//! of messages.
use anyhow::Result;
-use collections::HashMap;
use crate::client::Client;
-use crate::types;
+use crate::types::{self, Request};
pub struct ModelContextProtocol {
inner: Client,
@@ -43,7 +42,7 @@ impl ModelContextProtocol {
let response: types::InitializeResponse = self
.inner
- .request(types::RequestType::Initialize.as_str(), params)
+ .request(types::request::Initialize::METHOD, params)
.await?;
anyhow::ensure!(
@@ -94,137 +93,7 @@ impl InitializedContextServerProtocol {
}
}
- fn check_capability(&self, capability: ServerCapability) -> Result<()> {
- anyhow::ensure!(
- self.capable(capability),
- "Server does not support {capability:?} capability"
- );
- Ok(())
- }
-
- /// List the MCP prompts.
- pub async fn list_prompts(&self) -> Result<Vec<types::Prompt>> {
- self.check_capability(ServerCapability::Prompts)?;
-
- let response: types::PromptsListResponse = self
- .inner
- .request(
- types::RequestType::PromptsList.as_str(),
- serde_json::json!({}),
- )
- .await?;
-
- Ok(response.prompts)
- }
-
- /// List the MCP resources.
- pub async fn list_resources(&self) -> Result<types::ResourcesListResponse> {
- self.check_capability(ServerCapability::Resources)?;
-
- let response: types::ResourcesListResponse = self
- .inner
- .request(
- types::RequestType::ResourcesList.as_str(),
- serde_json::json!({}),
- )
- .await?;
-
- Ok(response)
- }
-
- /// Executes a prompt with the given arguments and returns the result.
- pub async fn run_prompt<P: AsRef<str>>(
- &self,
- prompt: P,
- arguments: HashMap<String, String>,
- ) -> Result<types::PromptsGetResponse> {
- self.check_capability(ServerCapability::Prompts)?;
-
- let params = types::PromptsGetParams {
- name: prompt.as_ref().to_string(),
- arguments: Some(arguments),
- meta: None,
- };
-
- let response: types::PromptsGetResponse = self
- .inner
- .request(types::RequestType::PromptsGet.as_str(), params)
- .await?;
-
- Ok(response)
- }
-
- pub async fn completion<P: Into<String>>(
- &self,
- reference: types::CompletionReference,
- argument: P,
- value: P,
- ) -> Result<types::Completion> {
- let params = types::CompletionCompleteParams {
- r#ref: reference,
- argument: types::CompletionArgument {
- name: argument.into(),
- value: value.into(),
- },
- meta: None,
- };
- let result: types::CompletionCompleteResponse = self
- .inner
- .request(types::RequestType::CompletionComplete.as_str(), params)
- .await?;
-
- let completion = types::Completion {
- values: result.completion.values,
- total: types::CompletionTotal::from_options(
- result.completion.has_more,
- result.completion.total,
- ),
- };
-
- Ok(completion)
- }
-
- /// List MCP tools.
- pub async fn list_tools(&self) -> Result<types::ListToolsResponse> {
- self.check_capability(ServerCapability::Tools)?;
-
- let response = self
- .inner
- .request::<types::ListToolsResponse>(types::RequestType::ListTools.as_str(), ())
- .await?;
-
- Ok(response)
- }
-
- /// Executes a tool with the given arguments
- pub async fn run_tool<P: AsRef<str>>(
- &self,
- tool: P,
- arguments: Option<HashMap<String, serde_json::Value>>,
- ) -> Result<types::CallToolResponse> {
- self.check_capability(ServerCapability::Tools)?;
-
- let params = types::CallToolParams {
- name: tool.as_ref().to_string(),
- arguments,
- meta: None,
- };
-
- let response: types::CallToolResponse = self
- .inner
- .request(types::RequestType::CallTool.as_str(), params)
- .await?;
-
- Ok(response)
- }
-}
-
-impl InitializedContextServerProtocol {
- pub async fn request<R: serde::de::DeserializeOwned>(
- &self,
- method: &str,
- params: impl serde::Serialize,
- ) -> Result<R> {
- self.inner.request(method, params).await
+ pub async fn request<T: Request>(&self, params: T::Params) -> Result<T::Response> {
+ self.inner.request(T::METHOD, params).await
}
}
@@ -0,0 +1,118 @@
+use anyhow::Context as _;
+use collections::HashMap;
+use futures::{Stream, StreamExt as _, lock::Mutex};
+use gpui::BackgroundExecutor;
+use std::{pin::Pin, sync::Arc};
+
+use crate::{
+ transport::Transport,
+ types::{Implementation, InitializeResponse, ProtocolVersion, ServerCapabilities},
+};
+
+pub fn create_fake_transport(
+ name: impl Into<String>,
+ executor: BackgroundExecutor,
+) -> FakeTransport {
+ let name = name.into();
+ FakeTransport::new(executor).on_request::<crate::types::request::Initialize>(move |_params| {
+ create_initialize_response(name.clone())
+ })
+}
+
+fn create_initialize_response(server_name: String) -> InitializeResponse {
+ InitializeResponse {
+ protocol_version: ProtocolVersion(crate::types::LATEST_PROTOCOL_VERSION.to_string()),
+ server_info: Implementation {
+ name: server_name,
+ version: "1.0.0".to_string(),
+ },
+ capabilities: ServerCapabilities::default(),
+ meta: None,
+ }
+}
+
+pub struct FakeTransport {
+ request_handlers:
+ HashMap<&'static str, Arc<dyn Fn(serde_json::Value) -> serde_json::Value + Send + Sync>>,
+ tx: futures::channel::mpsc::UnboundedSender<String>,
+ rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
+ executor: BackgroundExecutor,
+}
+
+impl FakeTransport {
+ pub fn new(executor: BackgroundExecutor) -> Self {
+ let (tx, rx) = futures::channel::mpsc::unbounded();
+ Self {
+ request_handlers: Default::default(),
+ tx,
+ rx: Arc::new(Mutex::new(rx)),
+ executor,
+ }
+ }
+
+ pub fn on_request<T: crate::types::Request>(
+ mut self,
+ handler: impl Fn(T::Params) -> T::Response + Send + Sync + 'static,
+ ) -> Self {
+ self.request_handlers.insert(
+ T::METHOD,
+ Arc::new(move |value| {
+ let params = value.get("params").expect("Missing parameters").clone();
+ let params: T::Params =
+ serde_json::from_value(params).expect("Invalid parameters received");
+ let response = handler(params);
+ serde_json::to_value(response).unwrap()
+ }),
+ );
+ self
+ }
+}
+
+#[async_trait::async_trait]
+impl Transport for FakeTransport {
+ async fn send(&self, message: String) -> anyhow::Result<()> {
+ if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
+ let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
+
+ if let Some(method) = msg.get("method") {
+ let method = method.as_str().expect("Invalid method received");
+ if let Some(handler) = self.request_handlers.get(method) {
+ let payload = handler(msg);
+ let response = serde_json::json!({
+ "jsonrpc": "2.0",
+ "id": id,
+ "result": payload
+ });
+ self.tx
+ .unbounded_send(response.to_string())
+ .context("sending a message")?;
+ } else {
+ log::debug!("No handler registered for MCP request '{method}'");
+ }
+ }
+ }
+ Ok(())
+ }
+
+ fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
+ let rx = self.rx.clone();
+ let executor = self.executor.clone();
+ Box::pin(futures::stream::unfold(rx, move |rx| {
+ let executor = executor.clone();
+ async move {
+ let mut rx_guard = rx.lock().await;
+ executor.simulate_random_delay().await;
+ if let Some(message) = rx_guard.next().await {
+ drop(rx_guard);
+ Some((message, rx))
+ } else {
+ None
+ }
+ }
+ }))
+ }
+
+ fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
+ Box::pin(futures::stream::empty())
+ }
+}
@@ -1,76 +1,92 @@
use collections::HashMap;
+use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use url::Url;
pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05";
-pub enum RequestType {
- Initialize,
- CallTool,
- ResourcesUnsubscribe,
- ResourcesSubscribe,
- ResourcesRead,
- ResourcesList,
- LoggingSetLevel,
- PromptsGet,
- PromptsList,
- CompletionComplete,
- Ping,
- ListTools,
- ListResourceTemplates,
- ListRoots,
-}
-
-impl RequestType {
- pub fn as_str(&self) -> &'static str {
- match self {
- RequestType::Initialize => "initialize",
- RequestType::CallTool => "tools/call",
- RequestType::ResourcesUnsubscribe => "resources/unsubscribe",
- RequestType::ResourcesSubscribe => "resources/subscribe",
- RequestType::ResourcesRead => "resources/read",
- RequestType::ResourcesList => "resources/list",
- RequestType::LoggingSetLevel => "logging/setLevel",
- RequestType::PromptsGet => "prompts/get",
- RequestType::PromptsList => "prompts/list",
- RequestType::CompletionComplete => "completion/complete",
- RequestType::Ping => "ping",
- RequestType::ListTools => "tools/list",
- RequestType::ListResourceTemplates => "resources/templates/list",
- RequestType::ListRoots => "roots/list",
- }
- }
-}
+pub mod request {
+ use super::*;
-impl TryFrom<&str> for RequestType {
- type Error = ();
-
- fn try_from(s: &str) -> Result<Self, Self::Error> {
- match s {
- "initialize" => Ok(RequestType::Initialize),
- "tools/call" => Ok(RequestType::CallTool),
- "resources/unsubscribe" => Ok(RequestType::ResourcesUnsubscribe),
- "resources/subscribe" => Ok(RequestType::ResourcesSubscribe),
- "resources/read" => Ok(RequestType::ResourcesRead),
- "resources/list" => Ok(RequestType::ResourcesList),
- "logging/setLevel" => Ok(RequestType::LoggingSetLevel),
- "prompts/get" => Ok(RequestType::PromptsGet),
- "prompts/list" => Ok(RequestType::PromptsList),
- "completion/complete" => Ok(RequestType::CompletionComplete),
- "ping" => Ok(RequestType::Ping),
- "tools/list" => Ok(RequestType::ListTools),
- "resources/templates/list" => Ok(RequestType::ListResourceTemplates),
- "roots/list" => Ok(RequestType::ListRoots),
- _ => Err(()),
- }
+ macro_rules! request {
+ ($method:expr, $name:ident, $params:ty, $response:ty) => {
+ pub struct $name;
+
+ impl Request for $name {
+ type Params = $params;
+ type Response = $response;
+ const METHOD: &'static str = $method;
+ }
+ };
}
+
+ request!(
+ "initialize",
+ Initialize,
+ InitializeParams,
+ InitializeResponse
+ );
+ request!("tools/call", CallTool, CallToolParams, CallToolResponse);
+ request!(
+ "resources/unsubscribe",
+ ResourcesUnsubscribe,
+ ResourcesUnsubscribeParams,
+ ()
+ );
+ request!(
+ "resources/subscribe",
+ ResourcesSubscribe,
+ ResourcesSubscribeParams,
+ ()
+ );
+ request!(
+ "resources/read",
+ ResourcesRead,
+ ResourcesReadParams,
+ ResourcesReadResponse
+ );
+ request!("resources/list", ResourcesList, (), ResourcesListResponse);
+ request!(
+ "logging/setLevel",
+ LoggingSetLevel,
+ LoggingSetLevelParams,
+ ()
+ );
+ request!(
+ "prompts/get",
+ PromptsGet,
+ PromptsGetParams,
+ PromptsGetResponse
+ );
+ request!("prompts/list", PromptsList, (), PromptsListResponse);
+ request!(
+ "completion/complete",
+ CompletionComplete,
+ CompletionCompleteParams,
+ CompletionCompleteResponse
+ );
+ request!("ping", Ping, (), ());
+ request!("tools/list", ListTools, (), ListToolsResponse);
+ request!(
+ "resources/templates/list",
+ ListResourceTemplates,
+ (),
+ ListResourceTemplatesResponse
+ );
+ request!("roots/list", ListRoots, (), ListRootsResponse);
+}
+
+pub trait Request {
+ type Params: DeserializeOwned + Serialize + Send + Sync + 'static;
+ type Response: DeserializeOwned + Serialize + Send + Sync + 'static;
+ const METHOD: &'static str;
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProtocolVersion(pub String);
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
pub protocol_version: ProtocolVersion,
@@ -80,7 +96,7 @@ pub struct InitializeParams {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolParams {
pub name: String,
@@ -90,7 +106,7 @@ pub struct CallToolParams {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesUnsubscribeParams {
pub uri: Url,
@@ -98,7 +114,7 @@ pub struct ResourcesUnsubscribeParams {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesSubscribeParams {
pub uri: Url,
@@ -106,7 +122,7 @@ pub struct ResourcesSubscribeParams {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesReadParams {
pub uri: Url,
@@ -114,7 +130,7 @@ pub struct ResourcesReadParams {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LoggingSetLevelParams {
pub level: LoggingLevel,
@@ -122,7 +138,7 @@ pub struct LoggingSetLevelParams {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsGetParams {
pub name: String,
@@ -132,37 +148,40 @@ pub struct PromptsGetParams {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionCompleteParams {
- pub r#ref: CompletionReference,
+ #[serde(rename = "ref")]
+ pub reference: CompletionReference,
pub argument: CompletionArgument,
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CompletionReference {
Prompt(PromptReference),
Resource(ResourceReference),
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptReference {
- pub r#type: PromptReferenceType,
+ #[serde(rename = "type")]
+ pub ty: PromptReferenceType,
pub name: String,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceReference {
- pub r#type: PromptReferenceType,
+ #[serde(rename = "type")]
+ pub ty: PromptReferenceType,
pub uri: Url,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromptReferenceType {
#[serde(rename = "ref/prompt")]
@@ -171,7 +190,7 @@ pub enum PromptReferenceType {
Resource,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionArgument {
pub name: String,
@@ -188,7 +207,7 @@ pub struct InitializeResponse {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesReadResponse {
pub contents: Vec<ResourceContentsType>,
@@ -196,14 +215,14 @@ pub struct ResourcesReadResponse {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ResourceContentsType {
Text(TextResourceContents),
Blob(BlobResourceContents),
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourcesListResponse {
pub resources: Vec<Resource>,
@@ -220,7 +239,7 @@ pub struct SamplingMessage {
pub content: MessageContent,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateMessageRequest {
pub messages: Vec<SamplingMessage>,
@@ -296,7 +315,7 @@ pub struct MessageAnnotations {
pub priority: Option<f64>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsGetResponse {
#[serde(skip_serializing_if = "Option::is_none")]
@@ -306,7 +325,7 @@ pub struct PromptsGetResponse {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptsListResponse {
pub prompts: Vec<Prompt>,
@@ -316,7 +335,7 @@ pub struct PromptsListResponse {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionCompleteResponse {
pub completion: CompletionResult,
@@ -324,7 +343,7 @@ pub struct CompletionCompleteResponse {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CompletionResult {
pub values: Vec<String>,
@@ -336,7 +355,7 @@ pub struct CompletionResult {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Prompt {
pub name: String,
@@ -346,7 +365,7 @@ pub struct Prompt {
pub arguments: Option<Vec<PromptArgument>>,
}
-#[derive(Debug, Deserialize, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptArgument {
pub name: String,
@@ -509,7 +528,7 @@ pub struct ModelHint {
pub name: Option<String>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum NotificationType {
Initialized,
@@ -589,7 +608,7 @@ pub struct Completion {
pub total: CompletionTotal,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallToolResponse {
pub content: Vec<ToolResponseContent>,
@@ -620,7 +639,7 @@ pub struct ListToolsResponse {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListResourceTemplatesResponse {
pub resource_templates: Vec<ResourceTemplate>,
@@ -630,7 +649,7 @@ pub struct ListResourceTemplatesResponse {
pub meta: Option<HashMap<String, serde_json::Value>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ListRootsResponse {
pub roots: Vec<Root>,
@@ -91,6 +91,7 @@ workspace-hack.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
+context_server = { workspace = true, features = ["test-support"] }
buffer_diff = { workspace = true, features = ["test-support"] }
dap = { workspace = true, features = ["test-support"] }
dap_adapters = { workspace = true, features = ["test-support"] }
@@ -499,17 +499,10 @@ impl ContextServerStore {
mod tests {
use super::*;
use crate::{FakeFs, Project, project_settings::ProjectSettings};
- use context_server::{
- transport::Transport,
- types::{
- self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
- ServerCapabilities,
- },
- };
- use futures::{Stream, StreamExt as _, lock::Mutex};
- use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
+ use context_server::test::create_fake_transport;
+ use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
use serde_json::json;
- use std::{cell::RefCell, pin::Pin, rc::Rc};
+ use std::{cell::RefCell, rc::Rc};
use util::path;
#[gpui::test]
@@ -532,33 +525,17 @@ mod tests {
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
});
- let server_1_id = ContextServerId("mcp-1".into());
- let server_2_id = ContextServerId("mcp-2".into());
-
- let transport_1 =
- Arc::new(FakeTransport::new(
- cx.executor(),
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response("mcp-1".to_string()))
- }
- _ => None,
- },
- ));
-
- let transport_2 =
- Arc::new(FakeTransport::new(
- cx.executor(),
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response("mcp-2".to_string()))
- }
- _ => None,
- },
- ));
+ let server_1_id = ContextServerId(SERVER_1_ID.into());
+ let server_2_id = ContextServerId(SERVER_2_ID.into());
- let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
- let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
+ let server_1 = Arc::new(ContextServer::new(
+ server_1_id.clone(),
+ Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
+ ));
+ let server_2 = Arc::new(ContextServer::new(
+ server_2_id.clone(),
+ Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
+ ));
store
.update(cx, |store, cx| store.start_server(server_1, cx))
@@ -627,33 +604,17 @@ mod tests {
ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
});
- let server_1_id = ContextServerId("mcp-1".into());
- let server_2_id = ContextServerId("mcp-2".into());
-
- let transport_1 =
- Arc::new(FakeTransport::new(
- cx.executor(),
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response("mcp-1".to_string()))
- }
- _ => None,
- },
- ));
-
- let transport_2 =
- Arc::new(FakeTransport::new(
- cx.executor(),
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response("mcp-2".to_string()))
- }
- _ => None,
- },
- ));
+ let server_1_id = ContextServerId(SERVER_1_ID.into());
+ let server_2_id = ContextServerId(SERVER_2_ID.into());
- let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
- let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
+ let server_1 = Arc::new(ContextServer::new(
+ server_1_id.clone(),
+ Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
+ ));
+ let server_2 = Arc::new(ContextServer::new(
+ server_2_id.clone(),
+ Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
+ ));
let _server_events = assert_server_events(
&store,
@@ -702,30 +663,14 @@ mod tests {
let server_id = ContextServerId(SERVER_1_ID.into());
- let transport_1 =
- Arc::new(FakeTransport::new(
- cx.executor(),
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response(SERVER_1_ID.to_string()))
- }
- _ => None,
- },
- ));
-
- let transport_2 =
- Arc::new(FakeTransport::new(
- cx.executor(),
- |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response(SERVER_1_ID.to_string()))
- }
- _ => None,
- },
- ));
-
- let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
- let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
+ let server_with_same_id_1 = Arc::new(ContextServer::new(
+ server_id.clone(),
+ Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
+ ));
+ let server_with_same_id_2 = Arc::new(ContextServer::new(
+ server_id.clone(),
+ Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
+ ));
// If we start another server with the same id, we should report that we stopped the previous one
let _server_events = assert_server_events(
@@ -794,16 +739,10 @@ mod tests {
let store = cx.new(|cx| {
ContextServerStore::test_maintain_server_loop(
Box::new(move |id, _| {
- let transport = FakeTransport::new(executor.clone(), {
- let id = id.0.clone();
- move |_, request_type, _| match request_type {
- Some(RequestType::Initialize) => {
- Some(create_initialize_response(id.clone().to_string()))
- }
- _ => None,
- }
- });
- Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
+ Arc::new(ContextServer::new(
+ id.clone(),
+ Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
+ ))
}),
registry.clone(),
project.read(cx).worktree_store(),
@@ -1033,99 +972,4 @@ mod tests {
(fs, project)
}
-
- fn create_initialize_response(server_name: String) -> serde_json::Value {
- serde_json::to_value(&InitializeResponse {
- protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
- server_info: Implementation {
- name: server_name,
- version: "1.0.0".to_string(),
- },
- capabilities: ServerCapabilities::default(),
- meta: None,
- })
- .unwrap()
- }
-
- struct FakeTransport {
- on_request: Arc<
- dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
- + Send
- + Sync,
- >,
- tx: futures::channel::mpsc::UnboundedSender<String>,
- rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
- executor: BackgroundExecutor,
- }
-
- impl FakeTransport {
- fn new(
- executor: BackgroundExecutor,
- on_request: impl Fn(
- u64,
- Option<RequestType>,
- serde_json::Value,
- ) -> Option<serde_json::Value>
- + 'static
- + Send
- + Sync,
- ) -> Self {
- let (tx, rx) = futures::channel::mpsc::unbounded();
- Self {
- on_request: Arc::new(on_request),
- tx,
- rx: Arc::new(Mutex::new(rx)),
- executor,
- }
- }
- }
-
- #[async_trait::async_trait]
- impl Transport for FakeTransport {
- async fn send(&self, message: String) -> Result<()> {
- if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
- let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
-
- if let Some(method) = msg.get("method") {
- let request_type = method
- .as_str()
- .and_then(|method| types::RequestType::try_from(method).ok());
- if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
- let response = serde_json::json!({
- "jsonrpc": "2.0",
- "id": id,
- "result": payload
- });
-
- self.tx
- .unbounded_send(response.to_string())
- .context("sending a message")?;
- }
- }
- }
- Ok(())
- }
-
- fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
- let rx = self.rx.clone();
- let executor = self.executor.clone();
- Box::pin(futures::stream::unfold(rx, move |rx| {
- let executor = executor.clone();
- async move {
- let mut rx_guard = rx.lock().await;
- executor.simulate_random_delay().await;
- if let Some(message) = rx_guard.next().await {
- drop(rx_guard);
- Some((message, rx))
- } else {
- None
- }
- }
- }))
- }
-
- fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
- Box::pin(futures::stream::empty())
- }
- }
}