1pub mod client;
2pub mod listener;
3pub mod protocol;
4#[cfg(any(test, feature = "test-support"))]
5pub mod test;
6pub mod transport;
7pub mod types;
8
9use std::path::Path;
10use std::sync::Arc;
11use std::{fmt::Display, path::PathBuf};
12
13use anyhow::Result;
14use client::Client;
15use collections::HashMap;
16use gpui::AsyncApp;
17use parking_lot::RwLock;
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use util::redact::should_redact;
21
/// Identifier of a context server; the string is shared cheaply through an `Arc<str>`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the identifier string as-is (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
30
31#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
32pub struct ContextServerCommand {
33 #[serde(rename = "command")]
34 pub path: PathBuf,
35 pub args: Vec<String>,
36 pub env: Option<HashMap<String, String>>,
37 /// Timeout for tool calls in milliseconds. Defaults to 60000 (60 seconds) if not specified.
38 pub timeout: Option<u64>,
39}
40
41impl std::fmt::Debug for ContextServerCommand {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 let filtered_env = self.env.as_ref().map(|env| {
44 env.iter()
45 .map(|(k, v)| (k, if should_redact(k) { "[REDACTED]" } else { v }))
46 .collect::<Vec<_>>()
47 });
48
49 f.debug_struct("ContextServerCommand")
50 .field("path", &self.path)
51 .field("args", &self.args)
52 .field("env", &filtered_env)
53 .finish()
54 }
55}
56
57enum ContextServerTransport {
58 Stdio(ContextServerCommand, Option<PathBuf>),
59 Custom(Arc<dyn crate::transport::Transport>),
60}
61
62pub struct ContextServer {
63 id: ContextServerId,
64 client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
65 configuration: ContextServerTransport,
66}
67
68impl ContextServer {
69 pub fn stdio(
70 id: ContextServerId,
71 command: ContextServerCommand,
72 working_directory: Option<Arc<Path>>,
73 ) -> Self {
74 Self {
75 id,
76 client: RwLock::new(None),
77 configuration: ContextServerTransport::Stdio(
78 command,
79 working_directory.map(|directory| directory.to_path_buf()),
80 ),
81 }
82 }
83
84 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
85 Self {
86 id,
87 client: RwLock::new(None),
88 configuration: ContextServerTransport::Custom(transport),
89 }
90 }
91
92 pub fn id(&self) -> ContextServerId {
93 self.id.clone()
94 }
95
96 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
97 self.client.read().clone()
98 }
99
100 pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
101 self.initialize(self.new_client(cx)?).await
102 }
103
104 /// Starts the context server, making sure handlers are registered before initialization happens
105 pub async fn start_with_handlers(
106 &self,
107 notification_handlers: Vec<(
108 &'static str,
109 Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
110 )>,
111 cx: &AsyncApp,
112 ) -> Result<()> {
113 let client = self.new_client(cx)?;
114 for (method, handler) in notification_handlers {
115 client.on_notification(method, handler);
116 }
117 self.initialize(client).await
118 }
119
120 fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
121 Ok(match &self.configuration {
122 ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
123 client::ContextServerId(self.id.0.clone()),
124 client::ModelContextServerBinary {
125 executable: Path::new(&command.path).to_path_buf(),
126 args: command.args.clone(),
127 env: command.env.clone(),
128 timeout: command.timeout,
129 },
130 working_directory,
131 cx.clone(),
132 )?,
133 ContextServerTransport::Custom(transport) => Client::new(
134 client::ContextServerId(self.id.0.clone()),
135 self.id().0,
136 transport.clone(),
137 None,
138 cx.clone(),
139 )?,
140 })
141 }
142
143 async fn initialize(&self, client: Client) -> Result<()> {
144 log::debug!("starting context server {}", self.id);
145 let protocol = crate::protocol::ModelContextProtocol::new(client);
146 let client_info = types::Implementation {
147 name: "Zed".to_string(),
148 version: env!("CARGO_PKG_VERSION").to_string(),
149 };
150 let initialized_protocol = protocol.initialize(client_info).await?;
151
152 log::debug!(
153 "context server {} initialized: {:?}",
154 self.id,
155 initialized_protocol.initialize,
156 );
157
158 *self.client.write() = Some(Arc::new(initialized_protocol));
159 Ok(())
160 }
161
162 pub fn stop(&self) -> Result<()> {
163 let mut client = self.client.write();
164 if let Some(protocol) = client.take() {
165 drop(protocol);
166 }
167 Ok(())
168 }
169}