1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ContextServer`: Trait for an individual context server (implemented by `NativeContextServer`)
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::pin::Pin;
19use std::sync::Arc;
20
21use anyhow::Result;
22use async_trait::async_trait;
23use collections::{HashMap, HashSet};
24use futures::{Future, FutureExt};
25use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
26use log;
27use parking_lot::RwLock;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::{Settings, SettingsSources};
31
32use crate::{
33 client::{self, Client},
34 types,
35};
36
37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
38pub struct ContextServerSettings {
39 pub servers: Vec<ServerConfig>,
40}
41
42#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
43pub struct ServerConfig {
44 pub id: String,
45 pub executable: String,
46 pub args: Vec<String>,
47 pub env: Option<HashMap<String, String>>,
48}
49
50impl Settings for ContextServerSettings {
51 const KEY: Option<&'static str> = Some("experimental.context_servers");
52
53 type FileContent = Self;
54
55 fn load(
56 sources: SettingsSources<Self::FileContent>,
57 _: &mut gpui::AppContext,
58 ) -> anyhow::Result<Self> {
59 sources.json_merge()
60 }
61}
62
/// A context server whose lifecycle is driven by `ContextServerManager`.
#[async_trait(?Send)]
pub trait ContextServer: Send + Sync + 'static {
    /// Identifier of this server; the manager keys servers by this value.
    fn id(&self) -> Arc<str>;
    /// The configuration this server was created from.
    fn config(&self) -> Arc<ServerConfig>;
    /// The initialized protocol handle, or `None` when the server is not
    /// running (e.g. not yet started, or already stopped).
    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
    /// Starts the server. The returned future is not `Send` and borrows `cx`
    /// for as long as it is alive.
    fn start<'a>(
        self: Arc<Self>,
        cx: &'a AsyncAppContext,
    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
    /// Stops the server.
    fn stop(&self) -> Result<()>;
}
74
/// A [`ContextServer`] whose `start` creates a [`Client`] for the configured
/// executable and performs the protocol `initialize` handshake.
pub struct NativeContextServer {
    /// Identifier, copied from `config.id` at construction.
    pub id: Arc<str>,
    /// Configuration this server was created from.
    pub config: Arc<ServerConfig>,
    /// Initialized protocol handle: `None` until `start` succeeds, and reset
    /// to `None` by `stop`.
    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
80
81impl NativeContextServer {
82 fn new(config: Arc<ServerConfig>) -> Self {
83 Self {
84 id: config.id.clone().into(),
85 config,
86 client: RwLock::new(None),
87 }
88 }
89}
90
91#[async_trait(?Send)]
92impl ContextServer for NativeContextServer {
93 fn id(&self) -> Arc<str> {
94 self.id.clone()
95 }
96
97 fn config(&self) -> Arc<ServerConfig> {
98 self.config.clone()
99 }
100
101 fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
102 self.client.read().clone()
103 }
104
105 fn start<'a>(
106 self: Arc<Self>,
107 cx: &'a AsyncAppContext,
108 ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
109 async move {
110 log::info!("starting context server {}", self.config.id,);
111 let client = Client::new(
112 client::ContextServerId(self.config.id.clone()),
113 client::ModelContextServerBinary {
114 executable: Path::new(&self.config.executable).to_path_buf(),
115 args: self.config.args.clone(),
116 env: self.config.env.clone(),
117 },
118 cx.clone(),
119 )?;
120
121 let protocol = crate::protocol::ModelContextProtocol::new(client);
122 let client_info = types::Implementation {
123 name: "Zed".to_string(),
124 version: env!("CARGO_PKG_VERSION").to_string(),
125 };
126 let initialized_protocol = protocol.initialize(client_info).await?;
127
128 log::debug!(
129 "context server {} initialized: {:?}",
130 self.config.id,
131 initialized_protocol.initialize,
132 );
133
134 *self.client.write() = Some(Arc::new(initialized_protocol));
135 Ok(())
136 }
137 .boxed_local()
138 }
139
140 fn stop(&self) -> Result<()> {
141 let mut client = self.client.write();
142 if let Some(protocol) = client.take() {
143 drop(protocol);
144 }
145 Ok(())
146 }
147}
148
/// A context server manager that starts and stops all configured servers.
/// To obtain a server to interact with, a crate must go through the
/// `GlobalContextServerManager`, which holds a model of the
/// `ContextServerManager`.
pub struct ContextServerManager {
    /// Servers whose start has completed, keyed by server id.
    servers: HashMap<String, Arc<dyn ContextServer>>,
    /// Ids of servers whose start has been requested but not yet completed.
    pending_servers: HashSet<String>,
}
157
/// Notifications emitted by [`ContextServerManager`] as servers change state.
pub enum Event {
    /// A server finished starting and is registered with the manager.
    ServerStarted { server_id: String },
    /// A server was stopped, including as part of a restart.
    ServerStopped { server_id: String },
}

// Allows subscribers to observe `Event`s emitted from the manager model.
impl EventEmitter<Event> for ContextServerManager {}
164
165impl Default for ContextServerManager {
166 fn default() -> Self {
167 Self::new()
168 }
169}
170
171impl ContextServerManager {
172 pub fn new() -> Self {
173 Self {
174 servers: HashMap::default(),
175 pending_servers: HashSet::default(),
176 }
177 }
178
179 pub fn add_server(
180 &mut self,
181 config: Arc<ServerConfig>,
182 cx: &ModelContext<Self>,
183 ) -> Task<anyhow::Result<()>> {
184 let server_id = config.id.clone();
185
186 if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
187 return Task::ready(Ok(()));
188 }
189
190 let task = {
191 let server_id = server_id.clone();
192 cx.spawn(|this, mut cx| async move {
193 let server = Arc::new(NativeContextServer::new(config));
194 server.clone().start(&cx).await?;
195 this.update(&mut cx, |this, cx| {
196 this.servers.insert(server_id.clone(), server);
197 this.pending_servers.remove(&server_id);
198 cx.emit(Event::ServerStarted {
199 server_id: server_id.clone(),
200 });
201 })?;
202 Ok(())
203 })
204 };
205
206 self.pending_servers.insert(server_id);
207 task
208 }
209
210 pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
211 self.servers.get(id).cloned()
212 }
213
214 pub fn remove_server(&mut self, id: &str, cx: &ModelContext<Self>) -> Task<anyhow::Result<()>> {
215 let id = id.to_string();
216 cx.spawn(|this, mut cx| async move {
217 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
218 server.stop()?;
219 }
220 this.update(&mut cx, |this, cx| {
221 this.pending_servers.remove(&id);
222 cx.emit(Event::ServerStopped {
223 server_id: id.clone(),
224 })
225 })?;
226 Ok(())
227 })
228 }
229
230 pub fn restart_server(
231 &mut self,
232 id: &Arc<str>,
233 cx: &mut ModelContext<Self>,
234 ) -> Task<anyhow::Result<()>> {
235 let id = id.to_string();
236 cx.spawn(|this, mut cx| async move {
237 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
238 server.stop()?;
239 let config = server.config();
240 let new_server = Arc::new(NativeContextServer::new(config));
241 new_server.clone().start(&cx).await?;
242 this.update(&mut cx, |this, cx| {
243 this.servers.insert(id.clone(), new_server);
244 cx.emit(Event::ServerStopped {
245 server_id: id.clone(),
246 });
247 cx.emit(Event::ServerStarted {
248 server_id: id.clone(),
249 });
250 })?;
251 }
252 Ok(())
253 })
254 }
255
256 pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
257 self.servers.values().cloned().collect()
258 }
259
260 pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
261 let current_servers = self
262 .servers()
263 .into_iter()
264 .map(|server| (server.id(), server.config()))
265 .collect::<HashMap<_, _>>();
266
267 let new_servers = settings
268 .servers
269 .iter()
270 .map(|config| (config.id.clone(), config.clone()))
271 .collect::<HashMap<_, _>>();
272
273 let servers_to_add = new_servers
274 .values()
275 .filter(|config| !current_servers.contains_key(config.id.as_str()))
276 .cloned()
277 .collect::<Vec<_>>();
278
279 let servers_to_remove = current_servers
280 .keys()
281 .filter(|id| !new_servers.contains_key(id.as_ref()))
282 .cloned()
283 .collect::<Vec<_>>();
284
285 log::trace!("servers_to_add={:?}", servers_to_add);
286 for config in servers_to_add {
287 self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
288 }
289
290 for id in servers_to_remove {
291 self.remove_server(&id, cx).detach_and_log_err(cx);
292 }
293 }
294}