1pub mod client;
2pub mod listener;
3pub mod oauth;
4pub mod protocol;
5#[cfg(any(test, feature = "test-support"))]
6pub mod test;
7pub mod transport;
8pub mod types;
9
10use collections::HashMap;
11use http_client::HttpClient;
12use std::path::Path;
13use std::sync::Arc;
14use std::time::Duration;
15use std::{fmt::Display, path::PathBuf};
16
17use anyhow::Result;
18use client::Client;
19use gpui::AsyncApp;
20use parking_lot::RwLock;
21pub use settings::ContextServerCommand;
22use url::Url;
23
24use crate::transport::HttpTransport;
25
/// Identifier of a context server.
///
/// Backed by an `Arc<str>`, so cloning the id only bumps a reference count.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the raw identifier text.
    ///
    /// Formatter flags such as width or fill are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
34
/// How a `ContextServer` builds its protocol client when it is started.
enum ContextServerTransport {
    /// Run the executable described by the `ContextServerCommand` (path, args, env,
    /// timeout), optionally in the given working directory.
    ///
    /// The client is built with `Client::stdio`.
    Stdio(ContextServerCommand, Option<PathBuf>),
    /// Use a caller-supplied transport, such as the `HttpTransport` built by
    /// `ContextServer::http`.
    ///
    /// The client is built with `Client::new`.
    Custom(Arc<dyn crate::transport::Transport>),
}
39
/// A configured MCP context server and, once started, its initialized protocol client.
pub struct ContextServer {
    /// Identifier used in log messages and passed down to the underlying client.
    id: ContextServerId,
    /// The initialized protocol client.
    ///
    /// `None` until `start` succeeds, and again after `stop`.
    client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Transport configuration; a fresh client is built from it on every `start`.
    configuration: ContextServerTransport,
    /// Per-request timeout forwarded to `Client::new` for custom/HTTP transports.
    ///
    /// Always `None` for servers created with `ContextServer::stdio`.
    request_timeout: Option<Duration>,
}
46
47impl ContextServer {
48 pub fn stdio(
49 id: ContextServerId,
50 command: ContextServerCommand,
51 working_directory: Option<Arc<Path>>,
52 ) -> Self {
53 Self {
54 id,
55 client: RwLock::new(None),
56 configuration: ContextServerTransport::Stdio(
57 command,
58 working_directory.map(|directory| directory.to_path_buf()),
59 ),
60 request_timeout: None,
61 }
62 }
63
64 pub fn http(
65 id: ContextServerId,
66 endpoint: &Url,
67 headers: HashMap<String, String>,
68 http_client: Arc<dyn HttpClient>,
69 executor: gpui::BackgroundExecutor,
70 request_timeout: Option<Duration>,
71 ) -> Result<Self> {
72 let transport = match endpoint.scheme() {
73 "http" | "https" => {
74 log::info!("Using HTTP transport for {}", endpoint);
75 let transport =
76 HttpTransport::new(http_client, endpoint.to_string(), headers, executor);
77 Arc::new(transport) as _
78 }
79 _ => anyhow::bail!("unsupported MCP url scheme {}", endpoint.scheme()),
80 };
81 Ok(Self::new_with_timeout(id, transport, request_timeout))
82 }
83
84 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
85 Self::new_with_timeout(id, transport, None)
86 }
87
88 pub fn new_with_timeout(
89 id: ContextServerId,
90 transport: Arc<dyn crate::transport::Transport>,
91 request_timeout: Option<Duration>,
92 ) -> Self {
93 Self {
94 id,
95 client: RwLock::new(None),
96 configuration: ContextServerTransport::Custom(transport),
97 request_timeout,
98 }
99 }
100
101 pub fn id(&self) -> ContextServerId {
102 self.id.clone()
103 }
104
105 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
106 self.client.read().clone()
107 }
108
109 pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
110 self.initialize(self.new_client(cx)?).await
111 }
112
113 fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
114 Ok(match &self.configuration {
115 ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
116 client::ContextServerId(self.id.0.clone()),
117 client::ModelContextServerBinary {
118 executable: Path::new(&command.path).to_path_buf(),
119 args: command.args.clone(),
120 env: command.env.clone(),
121 timeout: command.timeout,
122 },
123 working_directory,
124 cx.clone(),
125 )?,
126 ContextServerTransport::Custom(transport) => Client::new(
127 client::ContextServerId(self.id.0.clone()),
128 self.id().0,
129 transport.clone(),
130 self.request_timeout,
131 cx.clone(),
132 )?,
133 })
134 }
135
136 async fn initialize(&self, client: Client) -> Result<()> {
137 log::debug!("starting context server {}", self.id);
138 let protocol = crate::protocol::ModelContextProtocol::new(client);
139 let client_info = types::Implementation {
140 name: "Zed".to_string(),
141 version: env!("CARGO_PKG_VERSION").to_string(),
142 };
143 let initialized_protocol = protocol.initialize(client_info).await?;
144
145 log::debug!(
146 "context server {} initialized: {:?}",
147 self.id,
148 initialized_protocol.initialize,
149 );
150
151 *self.client.write() = Some(Arc::new(initialized_protocol));
152 Ok(())
153 }
154
155 pub fn stop(&self) -> Result<()> {
156 let mut client = self.client.write();
157 if let Some(protocol) = client.take() {
158 drop(protocol);
159 }
160 Ok(())
161 }
162}