1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ContextServer`: Trait implemented by each individual context server (`NativeContextServer` launches one from a configured command)
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::pin::Pin;
19use std::sync::Arc;
20
21use anyhow::{bail, Result};
22use async_trait::async_trait;
23use collections::{HashMap, HashSet};
24use futures::{Future, FutureExt};
25use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
26use log;
27use parking_lot::RwLock;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::{Settings, SettingsSources};
31
32use crate::{
33 client::{self, Client},
34 types,
35};
36
37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
38pub struct ContextServerSettings {
39 #[serde(default)]
40 pub context_servers: HashMap<Arc<str>, ServerConfig>,
41}
42
43#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
44pub struct ServerConfig {
45 pub command: Option<ServerCommand>,
46 pub settings: Option<serde_json::Value>,
47}
48
49#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
50pub struct ServerCommand {
51 pub path: String,
52 pub args: Vec<String>,
53 pub env: Option<HashMap<String, String>>,
54}
55
56impl Settings for ContextServerSettings {
57 const KEY: Option<&'static str> = None;
58
59 type FileContent = Self;
60
61 fn load(
62 sources: SettingsSources<Self::FileContent>,
63 _: &mut gpui::AppContext,
64 ) -> anyhow::Result<Self> {
65 sources.json_merge()
66 }
67}
68
69#[async_trait(?Send)]
70pub trait ContextServer: Send + Sync + 'static {
71 fn id(&self) -> Arc<str>;
72 fn config(&self) -> Arc<ServerConfig>;
73 fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
74 fn start<'a>(
75 self: Arc<Self>,
76 cx: &'a AsyncAppContext,
77 ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
78 fn stop(&self) -> Result<()>;
79}
80
81pub struct NativeContextServer {
82 pub id: Arc<str>,
83 pub config: Arc<ServerConfig>,
84 pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
85}
86
87impl NativeContextServer {
88 pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
89 Self {
90 id,
91 config,
92 client: RwLock::new(None),
93 }
94 }
95}
96
97#[async_trait(?Send)]
98impl ContextServer for NativeContextServer {
99 fn id(&self) -> Arc<str> {
100 self.id.clone()
101 }
102
103 fn config(&self) -> Arc<ServerConfig> {
104 self.config.clone()
105 }
106
107 fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
108 self.client.read().clone()
109 }
110
111 fn start<'a>(
112 self: Arc<Self>,
113 cx: &'a AsyncAppContext,
114 ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
115 async move {
116 log::info!("starting context server {}", self.id);
117 let Some(command) = &self.config.command else {
118 bail!("no command specified for server {}", self.id);
119 };
120 let client = Client::new(
121 client::ContextServerId(self.id.clone()),
122 client::ModelContextServerBinary {
123 executable: Path::new(&command.path).to_path_buf(),
124 args: command.args.clone(),
125 env: command.env.clone(),
126 },
127 cx.clone(),
128 )?;
129
130 let protocol = crate::protocol::ModelContextProtocol::new(client);
131 let client_info = types::Implementation {
132 name: "Zed".to_string(),
133 version: env!("CARGO_PKG_VERSION").to_string(),
134 };
135 let initialized_protocol = protocol.initialize(client_info).await?;
136
137 log::debug!(
138 "context server {} initialized: {:?}",
139 self.id,
140 initialized_protocol.initialize,
141 );
142
143 *self.client.write() = Some(Arc::new(initialized_protocol));
144 Ok(())
145 }
146 .boxed_local()
147 }
148
149 fn stop(&self) -> Result<()> {
150 let mut client = self.client.write();
151 if let Some(protocol) = client.take() {
152 drop(protocol);
153 }
154 Ok(())
155 }
156}
157
158/// A Context server manager manages the starting and stopping
159/// of all servers. To obtain a server to interact with, a crate
160/// must go through the `GlobalContextServerManager` which holds
161/// a model to the ContextServerManager.
162pub struct ContextServerManager {
163 servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
164 pending_servers: HashSet<Arc<str>>,
165}
166
/// Lifecycle notifications emitted by `ContextServerManager`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// The server with `server_id` finished starting and is now registered.
    ServerStarted { server_id: Arc<str> },
    /// The server with `server_id` was removed or stopped. A restart emits this
    /// followed by `ServerStarted`.
    ServerStopped { server_id: Arc<str> },
}
171
172impl EventEmitter<Event> for ContextServerManager {}
173
174impl Default for ContextServerManager {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180impl ContextServerManager {
181 pub fn new() -> Self {
182 Self {
183 servers: HashMap::default(),
184 pending_servers: HashSet::default(),
185 }
186 }
187
188 pub fn add_server(
189 &mut self,
190 server: Arc<dyn ContextServer>,
191 cx: &ModelContext<Self>,
192 ) -> Task<anyhow::Result<()>> {
193 let server_id = server.id();
194
195 if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
196 return Task::ready(Ok(()));
197 }
198
199 let task = {
200 let server_id = server_id.clone();
201 cx.spawn(|this, mut cx| async move {
202 server.clone().start(&cx).await?;
203 this.update(&mut cx, |this, cx| {
204 this.servers.insert(server_id.clone(), server);
205 this.pending_servers.remove(&server_id);
206 cx.emit(Event::ServerStarted {
207 server_id: server_id.clone(),
208 });
209 })?;
210 Ok(())
211 })
212 };
213
214 self.pending_servers.insert(server_id);
215 task
216 }
217
218 pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
219 self.servers.get(id).cloned()
220 }
221
222 pub fn remove_server(
223 &mut self,
224 id: &Arc<str>,
225 cx: &ModelContext<Self>,
226 ) -> Task<anyhow::Result<()>> {
227 let id = id.clone();
228 cx.spawn(|this, mut cx| async move {
229 if let Some(server) =
230 this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
231 {
232 server.stop()?;
233 }
234 this.update(&mut cx, |this, cx| {
235 this.pending_servers.remove(id.as_ref());
236 cx.emit(Event::ServerStopped {
237 server_id: id.clone(),
238 })
239 })?;
240 Ok(())
241 })
242 }
243
244 pub fn restart_server(
245 &mut self,
246 id: &Arc<str>,
247 cx: &mut ModelContext<Self>,
248 ) -> Task<anyhow::Result<()>> {
249 let id = id.clone();
250 cx.spawn(|this, mut cx| async move {
251 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
252 server.stop()?;
253 let config = server.config();
254 let new_server = Arc::new(NativeContextServer::new(id.clone(), config));
255 new_server.clone().start(&cx).await?;
256 this.update(&mut cx, |this, cx| {
257 this.servers.insert(id.clone(), new_server);
258 cx.emit(Event::ServerStopped {
259 server_id: id.clone(),
260 });
261 cx.emit(Event::ServerStarted {
262 server_id: id.clone(),
263 });
264 })?;
265 }
266 Ok(())
267 })
268 }
269
270 pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
271 self.servers.values().cloned().collect()
272 }
273
274 pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
275 let current_servers = self
276 .servers()
277 .into_iter()
278 .map(|server| (server.id(), server.config()))
279 .collect::<HashMap<_, _>>();
280
281 let new_servers = settings
282 .context_servers
283 .iter()
284 .map(|(id, config)| (id.clone(), config.clone()))
285 .collect::<HashMap<_, _>>();
286
287 let servers_to_add = new_servers
288 .iter()
289 .filter(|(id, _)| !current_servers.contains_key(id.as_ref()))
290 .map(|(id, config)| (id.clone(), config.clone()))
291 .collect::<Vec<_>>();
292
293 let servers_to_remove = current_servers
294 .keys()
295 .filter(|id| !new_servers.contains_key(id.as_ref()))
296 .cloned()
297 .collect::<Vec<_>>();
298
299 log::trace!("servers_to_add={:?}", servers_to_add);
300 for (id, config) in servers_to_add {
301 if config.command.is_some() {
302 let server = Arc::new(NativeContextServer::new(id, Arc::new(config)));
303 self.add_server(server, cx).detach_and_log_err(cx);
304 }
305 }
306
307 for id in servers_to_remove {
308 self.remove_server(&id, cx).detach_and_log_err(cx);
309 }
310 }
311}