1pub mod client;
2pub mod listener;
3pub mod protocol;
4#[cfg(any(test, feature = "test-support"))]
5pub mod test;
6pub mod transport;
7pub mod types;
8
9use collections::HashMap;
10use http_client::HttpClient;
11use std::path::Path;
12use std::sync::Arc;
13use std::time::Duration;
14use std::{fmt::Display, path::PathBuf};
15
16use anyhow::Result;
17use client::Client;
18use gpui::AsyncApp;
19use parking_lot::RwLock;
20pub use settings::ContextServerCommand;
21use url::Url;
22
23use crate::transport::HttpTransport;
24
/// Identifier of a context server, stored as a cheaply clonable shared string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Emits the raw identifier text. Width/fill flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
33
34enum ContextServerTransport {
35 Stdio(ContextServerCommand, Option<PathBuf>),
36 Custom(Arc<dyn crate::transport::Transport>),
37}
38
/// A single context server: its id, how to connect to it, and the initialized protocol
/// client once it has been started.
pub struct ContextServer {
    /// Identifier of this server; also forwarded to the underlying `Client`.
    id: ContextServerId,
    /// The initialized protocol. It is `None` until `initialize` succeeds and is reset to `None` by `stop`.
    // NOTE: field order determines drop order; this is dropped before `configuration`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Transport configuration used to build a fresh `Client` on every `start`.
    configuration: ContextServerTransport,
    /// Request timeout passed to `Client::new` for custom transports. Always `None` for stdio servers.
    request_timeout: Option<Duration>,
}
45
46impl ContextServer {
47 pub fn stdio(
48 id: ContextServerId,
49 command: ContextServerCommand,
50 working_directory: Option<Arc<Path>>,
51 ) -> Self {
52 Self {
53 id,
54 client: RwLock::new(None),
55 configuration: ContextServerTransport::Stdio(
56 command,
57 working_directory.map(|directory| directory.to_path_buf()),
58 ),
59 request_timeout: None, // Stdio handles timeout through command
60 }
61 }
62
63 pub fn http(
64 id: ContextServerId,
65 endpoint: &Url,
66 headers: HashMap<String, String>,
67 http_client: Arc<dyn HttpClient>,
68 executor: gpui::BackgroundExecutor,
69 request_timeout: Option<Duration>,
70 ) -> Result<Self> {
71 let transport = match endpoint.scheme() {
72 "http" | "https" => {
73 log::info!("Using HTTP transport for {}", endpoint);
74 let transport =
75 HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
76 Arc::new(transport) as _
77 }
78 _ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
79 };
80 Ok(Self::new_with_timeout(id, transport, request_timeout))
81 }
82
83 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
84 Self::new_with_timeout(id, transport, None)
85 }
86
87 pub fn new_with_timeout(
88 id: ContextServerId,
89 transport: Arc<dyn crate::transport::Transport>,
90 request_timeout: Option<Duration>,
91 ) -> Self {
92 Self {
93 id,
94 client: RwLock::new(None),
95 configuration: ContextServerTransport::Custom(transport),
96 request_timeout,
97 }
98 }
99
100 pub fn id(&self) -> ContextServerId {
101 self.id.clone()
102 }
103
104 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
105 self.client.read().clone()
106 }
107
108 pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
109 self.initialize(self.new_client(cx)?).await
110 }
111
112 fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
113 Ok(match &self.configuration {
114 ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
115 client::ContextServerId(self.id.0.clone()),
116 client::ModelContextServerBinary {
117 executable: Path::new(&command.path).to_path_buf(),
118 args: command.args.clone(),
119 env: command.env.clone(),
120 timeout: command.timeout,
121 },
122 working_directory,
123 cx.clone(),
124 )?,
125 ContextServerTransport::Custom(transport) => Client::new(
126 client::ContextServerId(self.id.0.clone()),
127 self.id().0,
128 transport.clone(),
129 self.request_timeout,
130 cx.clone(),
131 )?,
132 })
133 }
134
135 async fn initialize(&self, client: Client) -> Result<()> {
136 log::debug!("starting context server {}", self.id);
137 let protocol = crate::protocol::ModelContextProtocol::new(client);
138 let client_info = types::Implementation {
139 name: "Zed".to_string(),
140 version: env!("CARGO_PKG_VERSION").to_string(),
141 };
142 let initialized_protocol = protocol.initialize(client_info).await?;
143
144 log::debug!(
145 "context server {} initialized: {:?}",
146 self.id,
147 initialized_protocol.initialize,
148 );
149
150 *self.client.write() = Some(Arc::new(initialized_protocol));
151 Ok(())
152 }
153
154 pub fn stop(&self) -> Result<()> {
155 let mut client = self.client.write();
156 if let Some(protocol) = client.take() {
157 drop(protocol);
158 }
159 Ok(())
160 }
161}