1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
10//! - `ContextServer`: Represents an individual context server
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::{Result, bail};
21use collections::HashMap;
22use command_palette_hooks::CommandPaletteFilter;
23use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
24use log;
25use parking_lot::RwLock;
26use project::Project;
27use settings::{Settings, SettingsStore};
28use util::ResultExt as _;
29
30use crate::transport::Transport;
31use crate::{ContextServerSettings, ServerConfig};
32
33use crate::{
34 CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
35 client::{self, Client},
36 types,
37};
38
/// Lifecycle state of a context server, as tracked by the manager.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server is being launched and initialized.
    Starting,
    /// The initialization handshake completed; the server is usable.
    Running,
    /// The server failed to start; carries the error message.
    Error(Arc<str>),
}
45
/// An individual context (MCP) server: its identity, launch configuration,
/// and — once started — the initialized protocol client used to talk to it.
pub struct ContextServer {
    /// Unique identifier for this server (the key used in settings).
    pub id: Arc<str>,
    /// Configuration used to launch the server (command, args, env).
    pub config: Arc<ServerConfig>,
    /// Set after a successful `start`; `None` while the server is stopped.
    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Preconfigured transport (used by tests); when `None`, `start` spawns
    /// the configured command and communicates over stdio.
    transport: Option<Arc<dyn Transport>>,
}
52
53impl ContextServer {
54 pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
55 Self {
56 id,
57 config,
58 client: RwLock::new(None),
59 transport: None,
60 }
61 }
62
63 #[cfg(any(test, feature = "test-support"))]
64 pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
65 Arc::new(Self {
66 id,
67 client: RwLock::new(None),
68 config: Arc::new(ServerConfig::default()),
69 transport: Some(transport),
70 })
71 }
72
73 pub fn id(&self) -> Arc<str> {
74 self.id.clone()
75 }
76
77 pub fn config(&self) -> Arc<ServerConfig> {
78 self.config.clone()
79 }
80
81 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
82 self.client.read().clone()
83 }
84
85 pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
86 let client = if let Some(transport) = self.transport.clone() {
87 Client::new(
88 client::ContextServerId(self.id.clone()),
89 self.id(),
90 transport,
91 cx.clone(),
92 )?
93 } else {
94 let Some(command) = &self.config.command else {
95 bail!("no command specified for server {}", self.id);
96 };
97 Client::stdio(
98 client::ContextServerId(self.id.clone()),
99 client::ModelContextServerBinary {
100 executable: Path::new(&command.path).to_path_buf(),
101 args: command.args.clone(),
102 env: command.env.clone(),
103 },
104 cx.clone(),
105 )?
106 };
107 self.initialize(client).await
108 }
109
110 async fn initialize(&self, client: Client) -> Result<()> {
111 log::info!("starting context server {}", self.id);
112 let protocol = crate::protocol::ModelContextProtocol::new(client);
113 let client_info = types::Implementation {
114 name: "Zed".to_string(),
115 version: env!("CARGO_PKG_VERSION").to_string(),
116 };
117 let initialized_protocol = protocol.initialize(client_info).await?;
118
119 log::debug!(
120 "context server {} initialized: {:?}",
121 self.id,
122 initialized_protocol.initialize,
123 );
124
125 *self.client.write() = Some(Arc::new(initialized_protocol));
126 Ok(())
127 }
128
129 pub fn stop(&self) -> Result<()> {
130 let mut client = self.client.write();
131 if let Some(protocol) = client.take() {
132 drop(protocol);
133 }
134 Ok(())
135 }
136}
137
/// Manages the set of context servers for a project: starting, stopping, and
/// restarting them as settings and extension-provided descriptors change.
pub struct ContextServerManager {
    /// All known servers, keyed by id (running or not).
    servers: HashMap<Arc<str>, Arc<ContextServer>>,
    /// Last reported status per server id; entries are removed on stop.
    server_status: HashMap<Arc<str>, ContextServerStatus>,
    project: Entity<Project>,
    registry: Entity<ContextServerDescriptorRegistry>,
    /// In-flight reconciliation task; `Some` while a sync pass is running.
    update_servers_task: Option<Task<Result<()>>>,
    /// Set when configuration changes while a sync task is already running,
    /// so a follow-up pass is scheduled once the current one finishes.
    needs_server_update: bool,
    _subscriptions: Vec<Subscription>,
}
147
/// Events emitted by [`ContextServerManager`].
pub enum Event {
    /// A server's status changed; `status` is `None` when the server was
    /// stopped and its status cleared.
    ServerStatusChanged {
        server_id: Arc<str>,
        status: Option<ContextServerStatus>,
    },
}
154
// Lets the manager broadcast `Event`s to observers via `cx.emit`.
impl EventEmitter<Event> for ContextServerManager {}
156
impl ContextServerManager {
    /// Creates a manager for `project`'s context servers, subscribing to
    /// registry and settings changes, and kicks off an initial sync so
    /// servers from the current configuration start up.
    pub fn new(
        registry: Entity<ContextServerDescriptorRegistry>,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        let mut this = Self {
            _subscriptions: vec![
                cx.observe(&registry, |this, _registry, cx| {
                    this.available_context_servers_changed(cx);
                }),
                cx.observe_global::<SettingsStore>(|this, cx| {
                    this.available_context_servers_changed(cx);
                }),
            ],
            project,
            registry,
            needs_server_update: false,
            servers: HashMap::default(),
            server_status: HashMap::default(),
            update_servers_task: None,
        };
        // Initial reconciliation against current settings/registry contents.
        this.available_context_servers_changed(cx);
        this
    }

    /// Schedules a reconciliation of running servers against the desired
    /// configuration. If a sync task is already in flight, only records that
    /// another pass is needed (coalescing bursts of change notifications into
    /// at most one queued follow-up).
    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
        if self.update_servers_task.is_some() {
            self.needs_server_update = true;
        } else {
            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
                // Clear the flag before syncing: changes arriving during the
                // sync will set it again and trigger a follow-up pass below.
                this.update(cx, |this, _| {
                    this.needs_server_update = false;
                })?;

                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
                    log::error!("Error maintaining context servers: {}", err);
                }

                this.update(cx, |this, cx| {
                    let has_any_context_servers = !this.running_servers().is_empty();
                    if has_any_context_servers {
                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
                        });
                    }

                    // Mark the task finished, then re-run if changes arrived
                    // while we were syncing.
                    this.update_servers_task.take();
                    if this.needs_server_update {
                        this.available_context_servers_changed(cx);
                    }
                })?;

                Ok(())
            }));
        }
    }

    /// Returns the server with the given id, but only if it is running
    /// (i.e. it has an initialized client).
    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
        self.servers
            .get(id)
            .filter(|server| server.client().is_some())
            .cloned()
    }

    /// Returns the last reported status for the given server id, or `None`
    /// if the server is unknown or has been stopped.
    pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
        self.server_status.get(id).cloned()
    }

    /// Starts `server` in the background, registering it with the manager
    /// and reporting status transitions as it comes up.
    pub fn start_server(
        &self,
        server: Arc<ContextServer>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
    }

    /// Stops `server` and clears its status (emitting a status-changed event
    /// with `None`). Stop errors are logged rather than propagated.
    pub fn stop_server(
        &mut self,
        server: Arc<ContextServer>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        server.stop().log_err();
        self.update_server_status(server.id().clone(), None, cx);
        Ok(())
    }

    /// Restarts the server with the given id, if known: stops and removes the
    /// old instance, then starts a fresh one with the same configuration.
    pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        cx.spawn(async move |this, cx| {
            if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
                let config = server.config();

                this.update(cx, |this, cx| this.stop_server(server, cx))??;
                let new_server = Arc::new(ContextServer::new(id.clone(), config));
                Self::run_server(this, new_server, cx).await?;
            }
            Ok(())
        })
    }

    /// Returns every known server, running or not.
    pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
        self.servers.values().cloned().collect()
    }

    /// Returns only the servers that currently have an initialized client.
    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
        self.servers
            .values()
            .filter(|server| server.client().is_some())
            .cloned()
            .collect()
    }

    /// Reconciles running servers with the desired configuration: reads the
    /// settings (project-local when a worktree is visible), merges in
    /// commands from extension-provided descriptors, then stops servers that
    /// are no longer configured and (re)starts servers whose configuration is
    /// new or changed.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let mut desired_servers = HashMap::default();

        let (registry, project) = this.update(cx, |this, cx| {
            // Prefer settings scoped to the first visible worktree, falling
            // back to global settings when the project has none.
            let location = this
                .project
                .read(cx)
                .visible_worktrees(cx)
                .next()
                .map(|worktree| settings::SettingsLocation {
                    worktree_id: worktree.read(cx).id(),
                    path: Path::new(""),
                });
            let settings = ContextServerSettings::get(location, cx);
            desired_servers = settings.context_servers.clone();

            (this.registry.clone(), this.project.clone())
        })?;

        // Descriptors from extensions can supply a launch command for servers
        // whose settings don't specify one; explicit settings take priority.
        for (id, descriptor) in
            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
        {
            let config = desired_servers.entry(id).or_default();
            if config.command.is_none() {
                if let Some(extension_command) =
                    descriptor.command(project.clone(), &cx).await.log_err()
                {
                    config.command = Some(extension_command);
                }
            }
        }

        let mut servers_to_start = HashMap::default();
        let mut servers_to_stop = HashMap::default();

        this.update(cx, |this, _cx| {
            // Drop servers no longer present in the desired configuration.
            this.servers.retain(|id, server| {
                if desired_servers.contains_key(id) {
                    true
                } else {
                    servers_to_stop.insert(id.clone(), server.clone());
                    false
                }
            });

            // (Re)create servers whose configuration is new or has changed;
            // a changed config stops the old instance and starts a new one.
            for (id, config) in desired_servers {
                let existing_config = this.servers.get(&id).map(|server| server.config());
                if existing_config.as_deref() != Some(&config) {
                    let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
                    servers_to_start.insert(id.clone(), server.clone());
                    if let Some(old_server) = this.servers.remove(&id) {
                        servers_to_stop.insert(id, old_server);
                    }
                }
            }
        })?;

        for (_, server) in servers_to_stop {
            this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
        }

        for (_, server) in servers_to_start {
            // Start failures are logged inside `run_server`; a single bad
            // server should not abort the rest of the sync.
            Self::run_server(this.clone(), server, cx).await.ok();
        }

        Ok(())
    }

    /// Registers `server` with the manager and starts it, reporting
    /// `Starting` first, then `Running` on success or `Error(..)` on failure.
    async fn run_server(
        this: WeakEntity<Self>,
        server: Arc<ContextServer>,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let id = server.id();

        this.update(cx, |this, cx| {
            this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
            this.servers.insert(id.clone(), server.clone());
        })?;

        match server.start(&cx).await {
            Ok(_) => {
                log::debug!("`{}` context server started", id);
                this.update(cx, |this, cx| {
                    this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
                })?;
                Ok(())
            }
            Err(err) => {
                log::error!("`{}` context server failed to start\n{}", id, err);
                this.update(cx, |this, cx| {
                    this.update_server_status(
                        id.clone(),
                        Some(ContextServerStatus::Error(err.to_string().into())),
                        cx,
                    )
                })?;
                Err(err)
            }
        }
    }

    /// Records (or clears, when `status` is `None`) the status for `id` and
    /// emits `Event::ServerStatusChanged` so observers can react.
    fn update_server_status(
        &mut self,
        id: Arc<str>,
        status: Option<ContextServerStatus>,
        cx: &mut Context<Self>,
    ) {
        if let Some(status) = status.clone() {
            self.server_status.insert(id.clone(), status);
        } else {
            self.server_status.remove(&id);
        }

        cx.emit(Event::ServerStatusChanged {
            server_id: id,
            status,
        });
    }
}
390
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::types::{
        Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
    };

    use super::*;
    use futures::{Stream, StreamExt as _, lock::Mutex};
    use gpui::{AppContext as _, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use util::path;

    /// Verifies that `status_for_server` reflects each server's lifecycle:
    /// `Running` after a successful start, `None` before starting and after
    /// stopping, and that the two servers are tracked independently.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        init_test_settings(cx);
        let project = create_test_project(cx, json!({"code.rs": ""})).await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));

        let server_1_id: Arc<str> = "mcp-1".into();
        let server_2_id: Arc<str> = "mcp-2".into();

        // Each fake transport only answers the `initialize` request; every
        // other request gets no response.
        let transport_1 = Arc::new(FakeTransport::new(
            |_, request_type, _| match request_type {
                Some(RequestType::Initialize) => {
                    Some(create_initialize_response("mcp-1".to_string()))
                }
                _ => None,
            },
        ));

        let transport_2 = Arc::new(FakeTransport::new(
            |_, request_type, _| match request_type {
                Some(RequestType::Initialize) => {
                    Some(create_initialize_response("mcp-2".to_string()))
                }
                _ => None,
            },
        ));

        let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
        let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());

        manager
            .update(cx, |manager, cx| manager.start_server(server_1, cx))
            .await
            .unwrap();

        // Only server 1 has been started: it is Running, server 2 unknown.
        cx.update(|cx| {
            assert_eq!(
                manager.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
        });

        manager
            .update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
            .await
            .unwrap();

        // Both servers running.
        cx.update(|cx| {
            assert_eq!(
                manager.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                manager.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        manager
            .update(cx, |manager, cx| manager.stop_server(server_2, cx))
            .unwrap();

        // Stopping server 2 clears its status; server 1 is unaffected.
        cx.update(|cx| {
            assert_eq!(
                manager.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
        });
    }

    /// Builds a test project backed by an in-memory fake filesystem
    /// populated from the given JSON tree.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Installs the global settings store and registers the settings types
    /// that the manager reads during reconciliation.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            ContextServerSettings::register(cx);
        });
    }

    /// Builds the JSON payload a server would return from `initialize`.
    fn create_initialize_response(server_name: String) -> serde_json::Value {
        serde_json::to_value(&InitializeResponse {
            protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
            server_info: Implementation {
                name: server_name,
                version: "1.0.0".to_string(),
            },
            capabilities: ServerCapabilities::default(),
            meta: None,
        })
        .unwrap()
    }

    /// In-memory `Transport` stub: `send` parses the outgoing JSON-RPC
    /// message, consults `on_request` for a response payload, and pushes the
    /// response onto an internal channel that `receive` drains.
    struct FakeTransport {
        /// Maps (request id, parsed request type, full message) to an
        /// optional response payload.
        on_request: Arc<
            dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
                + Send
                + Sync,
        >,
        tx: futures::channel::mpsc::UnboundedSender<String>,
        /// Shared receiver so `receive` can produce a stream while `self` is
        /// only borrowed.
        rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
    }

    impl FakeTransport {
        fn new(
            on_request: impl Fn(
                u64,
                Option<RequestType>,
                serde_json::Value,
            ) -> Option<serde_json::Value>
            + 'static
            + Send
            + Sync,
        ) -> Self {
            let (tx, rx) = futures::channel::mpsc::unbounded();
            Self {
                on_request: Arc::new(on_request),
                tx,
                rx: Arc::new(Mutex::new(rx)),
            }
        }
    }

    #[async_trait::async_trait]
    impl Transport for FakeTransport {
        /// Handles an outgoing message: if it is a JSON-RPC request and the
        /// `on_request` callback returns a payload, queue a matching
        /// response; non-JSON input and notifications are ignored.
        async fn send(&self, message: String) -> Result<()> {
            if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
                let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);

                if let Some(method) = msg.get("method") {
                    let request_type = method
                        .as_str()
                        .and_then(|method| types::RequestType::try_from(method).ok());
                    if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
                        let response = serde_json::json!({
                            "jsonrpc": "2.0",
                            "id": id,
                            "result": payload
                        });

                        self.tx
                            .unbounded_send(response.to_string())
                            .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
                    }
                }
            }
            Ok(())
        }

        /// Streams the responses queued by `send`. Each step of the unfold
        /// locks the shared receiver, reads one message, and releases the
        /// lock before yielding, so the stream can be polled repeatedly.
        fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
            let rx = self.rx.clone();
            Box::pin(futures::stream::unfold(rx, |rx| async move {
                let mut rx_guard = rx.lock().await;
                if let Some(message) = rx_guard.next().await {
                    drop(rx_guard);
                    Some((message, rx))
                } else {
                    None
                }
            }))
        }

        /// The fake never produces stderr output.
        fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
            Box::pin(futures::stream::empty())
        }
    }
}