1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
10//! - `ContextServer`: Represents an individual context server
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::pin::Pin;
19use std::sync::Arc;
20
21use anyhow::Result;
22use async_trait::async_trait;
23use collections::{HashMap, HashSet};
24use futures::{Future, FutureExt};
25use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
26use log;
27use parking_lot::RwLock;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::{Settings, SettingsSources};
31
32use crate::{
33 client::{self, Client},
34 types,
35};
36
37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
38pub struct ContextServerSettings {
39 pub servers: Vec<ServerConfig>,
40}
41
42#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
43pub struct ServerConfig {
44 pub id: String,
45 pub executable: String,
46 pub args: Vec<String>,
47 pub env: Option<HashMap<String, String>>,
48}
49
50impl Settings for ContextServerSettings {
51 const KEY: Option<&'static str> = Some("experimental.context_servers");
52
53 type FileContent = Self;
54
55 fn load(
56 sources: SettingsSources<Self::FileContent>,
57 _: &mut gpui::AppContext,
58 ) -> anyhow::Result<Self> {
59 sources.json_merge()
60 }
61}
62
63#[async_trait(?Send)]
64pub trait ContextServer: Send + Sync + 'static {
65 fn id(&self) -> Arc<str>;
66 fn config(&self) -> Arc<ServerConfig>;
67 fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
68 fn start<'a>(
69 self: Arc<Self>,
70 cx: &'a AsyncAppContext,
71 ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
72 fn stop(&self) -> Result<()>;
73}
74
75pub struct NativeContextServer {
76 pub id: Arc<str>,
77 pub config: Arc<ServerConfig>,
78 pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
79}
80
81impl NativeContextServer {
82 pub fn new(config: Arc<ServerConfig>) -> Self {
83 Self {
84 id: config.id.clone().into(),
85 config,
86 client: RwLock::new(None),
87 }
88 }
89}
90
91#[async_trait(?Send)]
92impl ContextServer for NativeContextServer {
93 fn id(&self) -> Arc<str> {
94 self.id.clone()
95 }
96
97 fn config(&self) -> Arc<ServerConfig> {
98 self.config.clone()
99 }
100
101 fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
102 self.client.read().clone()
103 }
104
105 fn start<'a>(
106 self: Arc<Self>,
107 cx: &'a AsyncAppContext,
108 ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
109 async move {
110 log::info!("starting context server {}", self.config.id,);
111 let client = Client::new(
112 client::ContextServerId(self.config.id.clone()),
113 client::ModelContextServerBinary {
114 executable: Path::new(&self.config.executable).to_path_buf(),
115 args: self.config.args.clone(),
116 env: self.config.env.clone(),
117 },
118 cx.clone(),
119 )?;
120
121 let protocol = crate::protocol::ModelContextProtocol::new(client);
122 let client_info = types::Implementation {
123 name: "Zed".to_string(),
124 version: env!("CARGO_PKG_VERSION").to_string(),
125 };
126 let initialized_protocol = protocol.initialize(client_info).await?;
127
128 log::debug!(
129 "context server {} initialized: {:?}",
130 self.config.id,
131 initialized_protocol.initialize,
132 );
133
134 *self.client.write() = Some(Arc::new(initialized_protocol));
135 Ok(())
136 }
137 .boxed_local()
138 }
139
140 fn stop(&self) -> Result<()> {
141 let mut client = self.client.write();
142 if let Some(protocol) = client.take() {
143 drop(protocol);
144 }
145 Ok(())
146 }
147}
148
149/// A Context server manager manages the starting and stopping
150/// of all servers. To obtain a server to interact with, a crate
151/// must go through the `GlobalContextServerManager` which holds
152/// a model to the ContextServerManager.
153pub struct ContextServerManager {
154 servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
155 pending_servers: HashSet<Arc<str>>,
156}
157
158pub enum Event {
159 ServerStarted { server_id: Arc<str> },
160 ServerStopped { server_id: Arc<str> },
161}
162
163impl EventEmitter<Event> for ContextServerManager {}
164
165impl Default for ContextServerManager {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl ContextServerManager {
172 pub fn new() -> Self {
173 Self {
174 servers: HashMap::default(),
175 pending_servers: HashSet::default(),
176 }
177 }
178
179 pub fn add_server(
180 &mut self,
181 server: Arc<dyn ContextServer>,
182 cx: &ModelContext<Self>,
183 ) -> Task<anyhow::Result<()>> {
184 let server_id = server.id();
185
186 if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
187 return Task::ready(Ok(()));
188 }
189
190 let task = {
191 let server_id = server_id.clone();
192 cx.spawn(|this, mut cx| async move {
193 server.clone().start(&cx).await?;
194 this.update(&mut cx, |this, cx| {
195 this.servers.insert(server_id.clone(), server);
196 this.pending_servers.remove(&server_id);
197 cx.emit(Event::ServerStarted {
198 server_id: server_id.clone(),
199 });
200 })?;
201 Ok(())
202 })
203 };
204
205 self.pending_servers.insert(server_id);
206 task
207 }
208
209 pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
210 self.servers.get(id).cloned()
211 }
212
213 pub fn remove_server(
214 &mut self,
215 id: &Arc<str>,
216 cx: &ModelContext<Self>,
217 ) -> Task<anyhow::Result<()>> {
218 let id = id.clone();
219 cx.spawn(|this, mut cx| async move {
220 if let Some(server) =
221 this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
222 {
223 server.stop()?;
224 }
225 this.update(&mut cx, |this, cx| {
226 this.pending_servers.remove(id.as_ref());
227 cx.emit(Event::ServerStopped {
228 server_id: id.clone(),
229 })
230 })?;
231 Ok(())
232 })
233 }
234
235 pub fn restart_server(
236 &mut self,
237 id: &Arc<str>,
238 cx: &mut ModelContext<Self>,
239 ) -> Task<anyhow::Result<()>> {
240 let id = id.clone();
241 cx.spawn(|this, mut cx| async move {
242 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
243 server.stop()?;
244 let config = server.config();
245 let new_server = Arc::new(NativeContextServer::new(config));
246 new_server.clone().start(&cx).await?;
247 this.update(&mut cx, |this, cx| {
248 this.servers.insert(id.clone(), new_server);
249 cx.emit(Event::ServerStopped {
250 server_id: id.clone(),
251 });
252 cx.emit(Event::ServerStarted {
253 server_id: id.clone(),
254 });
255 })?;
256 }
257 Ok(())
258 })
259 }
260
261 pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
262 self.servers.values().cloned().collect()
263 }
264
265 pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
266 let current_servers = self
267 .servers()
268 .into_iter()
269 .map(|server| (server.id(), server.config()))
270 .collect::<HashMap<_, _>>();
271
272 let new_servers = settings
273 .servers
274 .iter()
275 .map(|config| (config.id.clone(), config.clone()))
276 .collect::<HashMap<_, _>>();
277
278 let servers_to_add = new_servers
279 .values()
280 .filter(|config| !current_servers.contains_key(config.id.as_str()))
281 .cloned()
282 .collect::<Vec<_>>();
283
284 let servers_to_remove = current_servers
285 .keys()
286 .filter(|id| !new_servers.contains_key(id.as_ref()))
287 .cloned()
288 .collect::<Vec<_>>();
289
290 log::trace!("servers_to_add={:?}", servers_to_add);
291 for config in servers_to_add {
292 let server = Arc::new(NativeContextServer::new(Arc::new(config)));
293 self.add_server(server, cx).detach_and_log_err(cx);
294 }
295
296 for id in servers_to_remove {
297 self.remove_server(&id, cx).detach_and_log_err(cx);
298 }
299 }
300}