manager.rs

//! This module implements context server management for Zed.
//!
//! It provides functionality to:
//! - Resolve the desired set of context servers from `ContextServerSettings` and the
//!   `ContextServerDescriptorRegistry`
//! - Manage individual context servers (start, stop, restart)
//! - Track the status of every context server and emit an event whenever it changes
//!
//! Key components:
//! - `ContextServerStatus`: The lifecycle state of a server (`Starting`, `Running`, `Error`)
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages multiple context servers for a project
//!
//! `ContextServerSettings` itself is defined elsewhere in this crate. `ContextServerManager`
//! observes both the settings store and the descriptor registry, and reconciles the set of
//! running servers whenever either one changes.
 16
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20use anyhow::{Result, bail};
 21use collections::HashMap;
 22use command_palette_hooks::CommandPaletteFilter;
 23use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
 24use log;
 25use parking_lot::RwLock;
 26use project::Project;
 27use settings::{Settings, SettingsStore};
 28use util::ResultExt as _;
 29
 30use crate::transport::Transport;
 31use crate::{ContextServerSettings, ServerConfig};
 32
 33use crate::{
 34    CONTEXT_SERVERS_NAMESPACE, ContextServerDescriptorRegistry,
 35    client::{self, Client},
 36    types,
 37};
 38
 39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 40pub enum ContextServerStatus {
 41    Starting,
 42    Running,
 43    Error(Arc<str>),
 44}
 45
/// A single context server and, once started, its initialized protocol connection.
///
/// `client` is `None` until [`ContextServer::start`] succeeds, and is reset to `None` by
/// [`ContextServer::stop`].
pub struct ContextServer {
    /// Identifier of the server; also the key under which `ContextServerManager` stores it.
    pub id: Arc<str>,
    /// Server configuration, including the optional command used to launch it.
    pub config: Arc<ServerConfig>,
    /// Initialized protocol handle; present only after a successful start.
    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
    /// Pre-built transport. When set, `start` uses it instead of launching `config.command`.
    transport: Option<Arc<dyn Transport>>,
}
 52
 53impl ContextServer {
 54    pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
 55        Self {
 56            id,
 57            config,
 58            client: RwLock::new(None),
 59            transport: None,
 60        }
 61    }
 62
 63    #[cfg(any(test, feature = "test-support"))]
 64    pub fn test(id: Arc<str>, transport: Arc<dyn crate::transport::Transport>) -> Arc<Self> {
 65        Arc::new(Self {
 66            id,
 67            client: RwLock::new(None),
 68            config: Arc::new(ServerConfig::default()),
 69            transport: Some(transport),
 70        })
 71    }
 72
 73    pub fn id(&self) -> Arc<str> {
 74        self.id.clone()
 75    }
 76
 77    pub fn config(&self) -> Arc<ServerConfig> {
 78        self.config.clone()
 79    }
 80
 81    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 82        self.client.read().clone()
 83    }
 84
 85    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
 86        let client = if let Some(transport) = self.transport.clone() {
 87            Client::new(
 88                client::ContextServerId(self.id.clone()),
 89                self.id(),
 90                transport,
 91                cx.clone(),
 92            )?
 93        } else {
 94            let Some(command) = &self.config.command else {
 95                bail!("no command specified for server {}", self.id);
 96            };
 97            Client::stdio(
 98                client::ContextServerId(self.id.clone()),
 99                client::ModelContextServerBinary {
100                    executable: Path::new(&command.path).to_path_buf(),
101                    args: command.args.clone(),
102                    env: command.env.clone(),
103                },
104                cx.clone(),
105            )?
106        };
107        self.initialize(client).await
108    }
109
110    async fn initialize(&self, client: Client) -> Result<()> {
111        log::info!("starting context server {}", self.id);
112        let protocol = crate::protocol::ModelContextProtocol::new(client);
113        let client_info = types::Implementation {
114            name: "Zed".to_string(),
115            version: env!("CARGO_PKG_VERSION").to_string(),
116        };
117        let initialized_protocol = protocol.initialize(client_info).await?;
118
119        log::debug!(
120            "context server {} initialized: {:?}",
121            self.id,
122            initialized_protocol.initialize,
123        );
124
125        *self.client.write() = Some(Arc::new(initialized_protocol));
126        Ok(())
127    }
128
129    pub fn stop(&self) -> Result<()> {
130        let mut client = self.client.write();
131        if let Some(protocol) = client.take() {
132            drop(protocol);
133        }
134        Ok(())
135    }
136}
137
/// Owns the context servers for a project and keeps them in sync with settings and the
/// descriptor registry.
pub struct ContextServerManager {
    /// Every registered server by id, including servers that were stopped or failed to start.
    servers: HashMap<Arc<str>, Arc<ContextServer>>,
    /// Last recorded status per server id. A missing entry means the server was never started
    /// or has been stopped.
    server_status: HashMap<Arc<str>, ContextServerStatus>,
    project: Entity<Project>,
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The reconciliation task currently in flight, if any.
    update_servers_task: Option<Task<Result<()>>>,
    /// Set when a reconciliation is requested while another one is still running.
    needs_server_update: bool,
    /// Keeps the registry and settings observers alive as long as the manager exists.
    _subscriptions: Vec<Subscription>,
}

/// Events emitted by `ContextServerManager`.
pub enum Event {
    /// A server's status changed. `status` is `None` when the server was stopped.
    ServerStatusChanged {
        server_id: Arc<str>,
        status: Option<ContextServerStatus>,
    },
}

impl EventEmitter<Event> for ContextServerManager {}
156
impl ContextServerManager {
    /// Creates a manager for `project` and schedules an initial reconciliation.
    ///
    /// Another reconciliation is scheduled whenever `registry` notifies or the global
    /// `SettingsStore` changes.
    pub fn new(
        registry: Entity<ContextServerDescriptorRegistry>,
        project: Entity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        // Struct-literal fields are evaluated in source order. `_subscriptions` comes first, so
        // `registry` is borrowed by `observe` before it is moved into the struct.
        let mut this = Self {
            _subscriptions: vec![
                cx.observe(&registry, |this, _registry, cx| {
                    this.available_context_servers_changed(cx);
                }),
                cx.observe_global::<SettingsStore>(|this, cx| {
                    this.available_context_servers_changed(cx);
                }),
            ],
            project,
            registry,
            needs_server_update: false,
            servers: HashMap::default(),
            server_status: HashMap::default(),
            update_servers_task: None,
        };
        this.available_context_servers_changed(cx);
        this
    }

    /// Schedules a reconciliation of the running servers against the desired configuration.
    ///
    /// At most one reconciliation task runs at a time. If one is already running, this only
    /// sets `needs_server_update`. The running task then schedules one more pass when it
    /// finishes, so bursts of changes are coalesced.
    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
        if self.update_servers_task.is_some() {
            self.needs_server_update = true;
        } else {
            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
                // Clear the flag before reconciling, so requests that arrive while
                // `maintain_servers` runs trigger a follow-up pass.
                this.update(cx, |this, _| {
                    this.needs_server_update = false;
                })?;

                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
                    log::error!("Error maintaining context servers: {}", err);
                }

                this.update(cx, |this, cx| {
                    // Expose the context-server commands once at least one server is running.
                    // NOTE(review): the namespace is never hidden here when no servers remain
                    // running — confirm that is intended.
                    let has_any_context_servers = !this.running_servers().is_empty();
                    if has_any_context_servers {
                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
                        });
                    }

                    // Clear the slot first, so the call below spawns a new task instead of
                    // only setting the flag again.
                    this.update_servers_task.take();
                    if this.needs_server_update {
                        this.available_context_servers_changed(cx);
                    }
                })?;

                Ok(())
            }));
        }
    }

    /// Returns the server with the given id, but only if it has an initialized client
    /// (i.e. it is running).
    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
        self.servers
            .get(id)
            .filter(|server| server.client().is_some())
            .cloned()
    }

    /// Returns the last recorded status for `id`. Returns `None` if the server was never
    /// started or has been stopped.
    pub fn status_for_server(&self, id: &str) -> Option<ContextServerStatus> {
        self.server_status.get(id).cloned()
    }

    /// Starts `server` in a background task.
    ///
    /// The task registers the server and reports `Starting`, then `Running` or `Error`.
    pub fn start_server(
        &self,
        server: Arc<ContextServer>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        cx.spawn(async move |this, cx| Self::run_server(this, server, cx).await)
    }

    /// Stops `server` and removes its status, emitting a `None` status event.
    ///
    /// The server stays in `servers`, which is what lets `restart_server` find it by id later.
    /// Errors from `ContextServer::stop` are logged rather than returned, so this always
    /// returns `Ok(())`.
    pub fn stop_server(
        &mut self,
        server: Arc<ContextServer>,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        server.stop().log_err();
        self.update_server_status(server.id().clone(), None, cx);
        Ok(())
    }

    /// Stops the server registered under `id` and starts a fresh instance with the same config.
    ///
    /// Does nothing if no server is registered under `id`.
    pub fn restart_server(&mut self, id: &Arc<str>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let id = id.clone();
        cx.spawn(async move |this, cx| {
            if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
                let config = server.config();

                this.update(cx, |this, cx| this.stop_server(server, cx))??;
                // `run_server` re-inserts the new instance under the same id.
                // NOTE(review): `ContextServer::new` sets no transport, so a server created
                // with an injected transport (`ContextServer::test`) is restarted from
                // `config.command` instead.
                let new_server = Arc::new(ContextServer::new(id.clone(), config));
                Self::run_server(this, new_server, cx).await?;
            }
            Ok(())
        })
    }

    /// Returns every registered server, running or not.
    pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
        self.servers.values().cloned().collect()
    }

    /// Returns the registered servers that currently have an initialized client.
    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
        self.servers
            .values()
            .filter(|server| server.client().is_some())
            .cloned()
            .collect()
    }

    /// Reconciles `servers` with the desired configuration.
    ///
    /// How the desired set is built:
    /// - Start from the `context_servers` map in `ContextServerSettings`.
    /// - Add every descriptor from the registry.
    /// - A registry descriptor only supplies the command when settings did not configure one.
    ///
    /// Servers that are no longer desired, or whose configuration changed, are stopped. New or
    /// changed servers are then started one at a time. Start failures are ignored here;
    /// `run_server` has already recorded them as `Error` statuses.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let mut desired_servers = HashMap::default();

        let (registry, project) = this.update(cx, |this, cx| {
            // Settings are resolved at the root of the first visible worktree. If the project
            // has no visible worktree, they are resolved with no location.
            let location = this
                .project
                .read(cx)
                .visible_worktrees(cx)
                .next()
                .map(|worktree| settings::SettingsLocation {
                    worktree_id: worktree.read(cx).id(),
                    path: Path::new(""),
                });
            let settings = ContextServerSettings::get(location, cx);
            desired_servers = settings.context_servers.clone();

            (this.registry.clone(), this.project.clone())
        })?;

        // Add each registry descriptor to the desired set. Its command is used only when
        // settings left `command` unset; lookup failures are logged and skipped.
        for (id, descriptor) in
            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
        {
            let config = desired_servers.entry(id).or_default();
            if config.command.is_none() {
                if let Some(extension_command) =
                    descriptor.command(project.clone(), &cx).await.log_err()
                {
                    config.command = Some(extension_command);
                }
            }
        }

        let mut servers_to_start = HashMap::default();
        let mut servers_to_stop = HashMap::default();

        this.update(cx, |this, _cx| {
            // Unregister servers that are no longer desired.
            this.servers.retain(|id, server| {
                if desired_servers.contains_key(id) {
                    true
                } else {
                    servers_to_stop.insert(id.clone(), server.clone());
                    false
                }
            });

            // Replace servers that are new or whose config changed. Servers with an unchanged
            // config are left as-is, even if they are currently stopped or errored.
            for (id, config) in desired_servers {
                let existing_config = this.servers.get(&id).map(|server| server.config());
                if existing_config.as_deref() != Some(&config) {
                    let server = Arc::new(ContextServer::new(id.clone(), Arc::new(config)));
                    servers_to_start.insert(id.clone(), server.clone());
                    if let Some(old_server) = this.servers.remove(&id) {
                        servers_to_stop.insert(id, old_server);
                    }
                }
            }
        })?;

        // All stops happen before any starts. For a replaced server, its `None` status event
        // is therefore emitted before the new instance reports `Starting`.
        for (_, server) in servers_to_stop {
            this.update(cx, |this, cx| this.stop_server(server, cx).ok())?;
        }

        for (_, server) in servers_to_start {
            Self::run_server(this.clone(), server, cx).await.ok();
        }

        Ok(())
    }

    /// Registers `server`, marks it `Starting`, then starts it.
    ///
    /// Records `Running` on success or `Error(message)` on failure. The server stays
    /// registered even when start-up fails. The start error, if any, is returned to the caller.
    async fn run_server(
        this: WeakEntity<Self>,
        server: Arc<ContextServer>,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let id = server.id();

        this.update(cx, |this, cx| {
            this.update_server_status(id.clone(), Some(ContextServerStatus::Starting), cx);
            this.servers.insert(id.clone(), server.clone());
        })?;

        match server.start(&cx).await {
            Ok(_) => {
                log::debug!("`{}` context server started", id);
                this.update(cx, |this, cx| {
                    this.update_server_status(id.clone(), Some(ContextServerStatus::Running), cx)
                })?;
                Ok(())
            }
            Err(err) => {
                log::error!("`{}` context server failed to start\n{}", id, err);
                this.update(cx, |this, cx| {
                    this.update_server_status(
                        id.clone(),
                        Some(ContextServerStatus::Error(err.to_string().into())),
                        cx,
                    )
                })?;
                Err(err)
            }
        }
    }

    /// Records `status` for `id` and emits `Event::ServerStatusChanged`.
    ///
    /// A `None` status removes the entry. The event is emitted even if the status is
    /// unchanged.
    fn update_server_status(
        &mut self,
        id: Arc<str>,
        status: Option<ContextServerStatus>,
        cx: &mut Context<Self>,
    ) {
        if let Some(status) = status.clone() {
            self.server_status.insert(id.clone(), status);
        } else {
            self.server_status.remove(&id);
        }

        cx.emit(Event::ServerStatusChanged {
            server_id: id,
            status,
        });
    }
}
390
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use crate::types::{
        Implementation, InitializeResponse, ProtocolVersion, RequestType, ServerCapabilities,
    };

    use super::*;
    use futures::{Stream, StreamExt as _, lock::Mutex};
    use gpui::{AppContext as _, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use util::path;

    /// Starts two servers backed by fake transports, then stops one of them.
    ///
    /// After each step, checks that `status_for_server` reports `Running` or `None` as
    /// expected for each server.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        init_test_settings(cx);
        let project = create_test_project(cx, json!({"code.rs": ""})).await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let manager = cx.new(|cx| ContextServerManager::new(registry.clone(), project, cx));

        let server_1_id: Arc<str> = "mcp-1".into();
        let server_2_id: Arc<str> = "mcp-2".into();

        // Each fake transport answers only `initialize`; other requests get no reply.
        let transport_1 = Arc::new(FakeTransport::new(
            |_, request_type, _| match request_type {
                Some(RequestType::Initialize) => {
                    Some(create_initialize_response("mcp-1".to_string()))
                }
                _ => None,
            },
        ));

        let transport_2 = Arc::new(FakeTransport::new(
            |_, request_type, _| match request_type {
                Some(RequestType::Initialize) => {
                    Some(create_initialize_response("mcp-2".to_string()))
                }
                _ => None,
            },
        ));

        let server_1 = ContextServer::test(server_1_id.clone(), transport_1.clone());
        let server_2 = ContextServer::test(server_2_id.clone(), transport_2.clone());

        manager
            .update(cx, |manager, cx| manager.start_server(server_1, cx))
            .await
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                manager.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
        });

        manager
            .update(cx, |manager, cx| manager.start_server(server_2.clone(), cx))
            .await
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                manager.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                manager.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        // Stopping is synchronous, so the status can be checked right away.
        manager
            .update(cx, |manager, cx| manager.stop_server(server_2, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                manager.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(manager.read(cx).status_for_server(&server_2_id), None);
        });
    }

    /// Creates a `Project` over a fake filesystem whose `/test` directory contains `files`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Installs a test `SettingsStore` and registers the project and context-server settings.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            ContextServerSettings::register(cx);
        });
    }

    /// Builds the JSON `result` payload of an `initialize` response from a server named
    /// `server_name`, using the latest protocol version and default capabilities.
    fn create_initialize_response(server_name: String) -> serde_json::Value {
        serde_json::to_value(&InitializeResponse {
            protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
            server_info: Implementation {
                name: server_name,
                version: "1.0.0".to_string(),
            },
            capabilities: ServerCapabilities::default(),
            meta: None,
        })
        .unwrap()
    }

    /// In-memory `Transport` that answers requests through a caller-supplied callback.
    ///
    /// Responses are queued on an unbounded channel and delivered by `receive`.
    struct FakeTransport {
        /// Called with (request id, parsed request type, full message). Returning
        /// `Some(payload)` queues a JSON-RPC response with `payload` as its `result`.
        on_request: Arc<
            dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
                + Send
                + Sync,
        >,
        tx: futures::channel::mpsc::UnboundedSender<String>,
        /// Shared so each stream returned by `receive` can pull from the same queue.
        rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
    }

    impl FakeTransport {
        fn new(
            on_request: impl Fn(
                u64,
                Option<RequestType>,
                serde_json::Value,
            ) -> Option<serde_json::Value>
            + 'static
            + Send
            + Sync,
        ) -> Self {
            let (tx, rx) = futures::channel::mpsc::unbounded();
            Self {
                on_request: Arc::new(on_request),
                tx,
                rx: Arc::new(Mutex::new(rx)),
            }
        }
    }

    #[async_trait::async_trait]
    impl Transport for FakeTransport {
        /// Inspects an outgoing message and, if `on_request` produces a payload, queues a reply.
        ///
        /// A message gets no reply if it is not valid JSON, has no `method`, or `on_request`
        /// returns `None`. A missing `id` is treated as `0`.
        async fn send(&self, message: String) -> Result<()> {
            if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
                let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);

                if let Some(method) = msg.get("method") {
                    let request_type = method
                        .as_str()
                        .and_then(|method| types::RequestType::try_from(method).ok());
                    if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
                        let response = serde_json::json!({
                            "jsonrpc": "2.0",
                            "id": id,
                            "result": payload
                        });

                        self.tx
                            .unbounded_send(response.to_string())
                            .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
                    }
                }
            }
            Ok(())
        }

        /// Streams queued responses.
        ///
        /// The shared receiver is locked only while waiting for the next message, and each
        /// message is delivered to exactly one stream.
        fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
            let rx = self.rx.clone();
            Box::pin(futures::stream::unfold(rx, |rx| async move {
                let mut rx_guard = rx.lock().await;
                if let Some(message) = rx_guard.next().await {
                    drop(rx_guard);
                    Some((message, rx))
                } else {
                    None
                }
            }))
        }

        /// The fake never produces error output; this stream is always empty.
        fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
            Box::pin(futures::stream::empty())
        }
    }
}