1pub mod client;
2pub mod listener;
3pub mod protocol;
4#[cfg(any(test, feature = "test-support"))]
5pub mod test;
6pub mod transport;
7pub mod types;
8
9use std::path::Path;
10use std::sync::Arc;
11use std::{fmt::Display, path::PathBuf};
12
13use anyhow::Result;
14use client::Client;
15use collections::HashMap;
16use gpui::AsyncApp;
17use parking_lot::RwLock;
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use util::redact::should_redact;
21
/// Identifier of a context server.
///
/// The identifier is kept in an `Arc<str>`, so cloning it only bumps a
/// reference count instead of copying the string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextServerId(pub Arc<str>);

impl Display for ContextServerId {
    /// Writes the wrapped identifier string.
    ///
    /// Uses `Formatter::pad` rather than `write!(f, "{}", ..)` so that width,
    /// fill, alignment and precision flags given by the caller (for example
    /// `{:>10}`) are applied to the identifier instead of being dropped.
    /// With a plain `{}` the output is exactly the inner string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
30
31#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema)]
32pub struct ContextServerCommand {
33 #[serde(rename = "command")]
34 pub path: PathBuf,
35 pub args: Vec<String>,
36 pub env: Option<HashMap<String, String>>,
37}
38
39impl std::fmt::Debug for ContextServerCommand {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 let filtered_env = self.env.as_ref().map(|env| {
42 env.iter()
43 .map(|(k, v)| (k, if should_redact(k) { "[REDACTED]" } else { v }))
44 .collect::<Vec<_>>()
45 });
46
47 f.debug_struct("ContextServerCommand")
48 .field("path", &self.path)
49 .field("args", &self.args)
50 .field("env", &filtered_env)
51 .finish()
52 }
53}
54
55enum ContextServerTransport {
56 Stdio(ContextServerCommand, Option<PathBuf>),
57 Custom(Arc<dyn crate::transport::Transport>),
58}
59
60pub struct ContextServer {
61 id: ContextServerId,
62 client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
63 configuration: ContextServerTransport,
64}
65
66impl ContextServer {
67 pub fn stdio(
68 id: ContextServerId,
69 command: ContextServerCommand,
70 working_directory: Option<Arc<Path>>,
71 ) -> Self {
72 Self {
73 id,
74 client: RwLock::new(None),
75 configuration: ContextServerTransport::Stdio(
76 command,
77 working_directory.map(|directory| directory.to_path_buf()),
78 ),
79 }
80 }
81
82 pub fn new(id: ContextServerId, transport: Arc<dyn crate::transport::Transport>) -> Self {
83 Self {
84 id,
85 client: RwLock::new(None),
86 configuration: ContextServerTransport::Custom(transport),
87 }
88 }
89
90 pub fn id(&self) -> ContextServerId {
91 self.id.clone()
92 }
93
94 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
95 self.client.read().clone()
96 }
97
98 pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
99 self.initialize(self.new_client(cx)?).await
100 }
101
102 /// Starts the context server, making sure handlers are registered before initialization happens
103 pub async fn start_with_handlers(
104 &self,
105 notification_handlers: Vec<(
106 &'static str,
107 Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
108 )>,
109 cx: &AsyncApp,
110 ) -> Result<()> {
111 let client = self.new_client(cx)?;
112 for (method, handler) in notification_handlers {
113 client.on_notification(method, handler);
114 }
115 self.initialize(client).await
116 }
117
118 fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
119 Ok(match &self.configuration {
120 ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
121 client::ContextServerId(self.id.0.clone()),
122 client::ModelContextServerBinary {
123 executable: Path::new(&command.path).to_path_buf(),
124 args: command.args.clone(),
125 env: command.env.clone(),
126 },
127 working_directory,
128 cx.clone(),
129 )?,
130 ContextServerTransport::Custom(transport) => Client::new(
131 client::ContextServerId(self.id.0.clone()),
132 self.id().0,
133 transport.clone(),
134 cx.clone(),
135 )?,
136 })
137 }
138
139 async fn initialize(&self, client: Client) -> Result<()> {
140 log::debug!("starting context server {}", self.id);
141 let protocol = crate::protocol::ModelContextProtocol::new(client);
142 let client_info = types::Implementation {
143 name: "Zed".to_string(),
144 version: env!("CARGO_PKG_VERSION").to_string(),
145 };
146 let initialized_protocol = protocol.initialize(client_info).await?;
147
148 log::debug!(
149 "context server {} initialized: {:?}",
150 self.id,
151 initialized_protocol.initialize,
152 );
153
154 *self.client.write() = Some(Arc::new(initialized_protocol));
155 Ok(())
156 }
157
158 pub fn stop(&self) -> Result<()> {
159 let mut client = self.client.write();
160 if let Some(protocol) = client.take() {
161 drop(protocol);
162 }
163 Ok(())
164 }
165}