1pub mod client;
2pub mod listener;
3pub mod protocol;
4#[cfg(any(test, feature = "test-support"))]
5pub mod test;
6pub mod transport;
7pub mod types;
8
9use collections::HashMap;
10use http_client::HttpClient;
11use std::path::Path;
12use std::sync::Arc;
13use std::{fmt::Display, path::PathBuf};
14
15use anyhow::Result;
16use client::Client;
17use gpui::AsyncApp;
18use parking_lot::RwLock;
19pub use settings::ContextServerCommand;
20use url::Url;
21
22use crate::transport::HttpTransport;
23
/// Identifier of a context server.
///
/// Wraps a shared `Arc<str>`, so cloning an id only bumps a reference count.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the raw identifier text, without quoting or any decoration.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
32
33enum ContextServerTransport {
34 Stdio(ContextServerCommand, Option<PathBuf>),
35 Custom(Arc<dyn crate::transport::Transport>),
36}
37
/// A handle to a Model Context Protocol (MCP) server: its id, how to connect to it, and the
/// initialized protocol client once it has been started.
pub struct ContextServer {
    /// Identifier of this server; also passed as the id/name of the underlying client.
    id: ContextServerId,
    /// The initialized protocol client. `None` until `start`/`start_with_handlers`
    /// succeeds, and reset to `None` by `stop`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// How to build the client connection (stdio command or a custom transport).
    configuration: ContextServerTransport,
}
43
44impl ContextServer {
45 pub fn stdio(
46 id: ContextServerId,
47 command: ContextServerCommand,
48 working_directory: Option<Arc<Path>>,
49 ) -> Self {
50 Self {
51 id,
52 client: RwLock::new(None),
53 configuration: ContextServerTransport::Stdio(
54 command,
55 working_directory.map(|directory| directory.to_path_buf()),
56 ),
57 }
58 }
59
60 pub fn http(
61 id: ContextServerId,
62 endpoint: &Url,
63 headers: HashMap<String, String>,
64 http_client: Arc<dyn HttpClient>,
65 executor: gpui::BackgroundExecutor,
66 ) -> Result<Self> {
67 let transport = match endpoint.scheme() {
68 "http" | "https" => {
69 log::info!("Using HTTP transport for {}", endpoint);
70 let transport =
71 HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
72 Arc::new(transport) as _
73 }
74 _ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
75 };
76 Ok(Self::new(id, transport))
77 }
78
79 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
80 Self {
81 id,
82 client: RwLock::new(None),
83 configuration: ContextServerTransport::Custom(transport),
84 }
85 }
86
87 pub fn id(&self) -> ContextServerId {
88 self.id.clone()
89 }
90
91 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
92 self.client.read().clone()
93 }
94
95 pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
96 self.initialize(self.new_client(cx)?).await
97 }
98
99 /// Starts the context server, making sure handlers are registered before initialization happens
100 pub async fn start_with_handlers(
101 &self,
102 notification_handlers: Vec<(
103 &'static str,
104 Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
105 )>,
106 cx: &AsyncApp,
107 ) -> Result<()> {
108 let client = self.new_client(cx)?;
109 for (method, handler) in notification_handlers {
110 client.on_notification(method, handler);
111 }
112 self.initialize(client).await
113 }
114
115 fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
116 Ok(match &self.configuration {
117 ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
118 client::ContextServerId(self.id.0.clone()),
119 client::ModelContextServerBinary {
120 executable: Path::new(&command.path).to_path_buf(),
121 args: command.args.clone(),
122 env: command.env.clone(),
123 timeout: command.timeout,
124 },
125 working_directory,
126 cx.clone(),
127 )?,
128 ContextServerTransport::Custom(transport) => Client::new(
129 client::ContextServerId(self.id.0.clone()),
130 self.id().0,
131 transport.clone(),
132 None,
133 cx.clone(),
134 )?,
135 })
136 }
137
138 async fn initialize(&self, client: Client) -> Result<()> {
139 log::debug!("starting context server {}", self.id);
140 let protocol = crate::protocol::ModelContextProtocol::new(client);
141 let client_info = types::Implementation {
142 name: "Zed".to_string(),
143 version: env!("CARGO_PKG_VERSION").to_string(),
144 };
145 let initialized_protocol = protocol.initialize(client_info).await?;
146
147 log::debug!(
148 "context server {} initialized: {:?}",
149 self.id,
150 initialized_protocol.initialize,
151 );
152
153 *self.client.write() = Some(Arc::new(initialized_protocol));
154 Ok(())
155 }
156
157 pub fn stop(&self) -> Result<()> {
158 let mut client = self.client.write();
159 if let Some(protocol) = client.take() {
160 drop(protocol);
161 }
162 Ok(())
163 }
164}