manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
//! Key components:
//! - `ContextServerSettings`: Defines the structure for server configurations
//!   (declared elsewhere in this crate and read here)
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages the context servers associated with a project
//! - `Event`: Emitted by the manager when a server starts or stops
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20use anyhow::{Result, bail};
 21use collections::HashMap;
 22use command_palette_hooks::CommandPaletteFilter;
 23use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
 24use log;
 25use parking_lot::RwLock;
 26use project::Project;
 27use settings::{Settings, SettingsStore};
 28use util::ResultExt as _;
 29
 30use crate::{ContextServerSettings, ServerConfig};
 31
 32use crate::{
 33    CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry,
 34    client::{self, Client},
 35    types,
 36};
 37
/// A single context server: its identifier, its configuration, and — while it
/// is running — the initialized protocol client used to talk to it.
pub struct ContextServer {
    /// Identifier of this server; also used in log messages.
    pub id: Arc<str>,
    /// Configuration for this server, including the optional command to run.
    pub config: Arc<ServerConfig>,
    /// Initialized protocol handle; `None` until `start` succeeds and again after `stop`.
    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
 43
 44impl ContextServer {
 45    pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
 46        Self {
 47            id,
 48            config,
 49            client: RwLock::new(None),
 50        }
 51    }
 52
 53    pub fn id(&self) -> Arc<str> {
 54        self.id.clone()
 55    }
 56
 57    pub fn config(&self) -> Arc<ServerConfig> {
 58        self.config.clone()
 59    }
 60
 61    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 62        self.client.read().clone()
 63    }
 64
 65    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
 66        log::info!("starting context server {}", self.id);
 67        let Some(command) = &self.config.command else {
 68            bail!("no command specified for server {}", self.id);
 69        };
 70        let client = Client::new(
 71            client::ContextServerId(self.id.clone()),
 72            client::ModelContextServerBinary {
 73                executable: Path::new(&command.path).to_path_buf(),
 74                args: command.args.clone(),
 75                env: command.env.clone(),
 76            },
 77            cx.clone(),
 78        )?;
 79
 80        let protocol = crate::protocol::ModelContextProtocol::new(client);
 81        let client_info = types::Implementation {
 82            name: "Zed".to_string(),
 83            version: env!("CARGO_PKG_VERSION").to_string(),
 84        };
 85        let initialized_protocol = protocol.initialize(client_info).await?;
 86
 87        log::debug!(
 88            "context server {} initialized: {:?}",
 89            self.id,
 90            initialized_protocol.initialize,
 91        );
 92
 93        *self.client.write() = Some(Arc::new(initialized_protocol));
 94        Ok(())
 95    }
 96
 97    pub fn stop(&self) -> Result<()> {
 98        let mut client = self.client.write();
 99        if let Some(protocol) = client.take() {
100            drop(protocol);
101        }
102        Ok(())
103    }
104}
105
/// Tracks the context servers for a project and reconciles them with the
/// settings and the factory registry.
pub struct ContextServerManager {
    /// Known servers keyed by id; an entry may exist without a running client.
    servers: HashMap<Arc<str>, Arc<ContextServer>>,
    /// Project whose first visible worktree selects the settings location.
    project: Entity<Project>,
    /// Registry of factories consulted for servers that have no command configured.
    registry: Entity<ContextServerFactoryRegistry>,
    /// The in-flight reconciliation task, if one is running.
    update_servers_task: Option<Task<Result<()>>>,
    /// Set when a change arrives while a reconciliation task is already running.
    needs_server_update: bool,
    /// Observers of the registry and the settings store, registered in `new`.
    _subscriptions: Vec<Subscription>,
}
114
/// Events emitted by `ContextServerManager`.
pub enum Event {
    /// Emitted after the server identified by `server_id` started successfully.
    ServerStarted { server_id: Arc<str> },
    /// Emitted after the server identified by `server_id` was stopped.
    ServerStopped { server_id: Arc<str> },
}
119
// Lets `ContextServerManager` emit `Event` values through `cx.emit`.
impl EventEmitter<Event> for ContextServerManager {}
121
122impl ContextServerManager {
123    pub fn new(
124        registry: Entity<ContextServerFactoryRegistry>,
125        project: Entity<Project>,
126        cx: &mut Context<Self>,
127    ) -> Self {
128        let mut this = Self {
129            _subscriptions: vec![
130                cx.observe(&registry, |this, _registry, cx| {
131                    this.available_context_servers_changed(cx);
132                }),
133                cx.observe_global::<SettingsStore>(|this, cx| {
134                    this.available_context_servers_changed(cx);
135                }),
136            ],
137            project,
138            registry,
139            needs_server_update: false,
140            servers: HashMap::default(),
141            update_servers_task: None,
142        };
143        this.available_context_servers_changed(cx);
144        this
145    }
146
147    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
148        if self.update_servers_task.is_some() {
149            self.needs_server_update = true;
150        } else {
151            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
152                this.update(cx, |this, _| {
153                    this.needs_server_update = false;
154                })?;
155
156                Self::maintain_servers(this.clone(), cx).await?;
157
158                this.update(cx, |this, cx| {
159                    let has_any_context_servers = !this.running_servers().is_empty();
160                    if has_any_context_servers {
161                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
162                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
163                        });
164                    }
165
166                    this.update_servers_task.take();
167                    if this.needs_server_update {
168                        this.available_context_servers_changed(cx);
169                    }
170                })?;
171
172                Ok(())
173            }));
174        }
175    }
176
177    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
178        self.servers
179            .get(id)
180            .filter(|server| server.client().is_some())
181            .cloned()
182    }
183
184    pub fn start_server(
185        &self,
186        server: Arc<ContextServer>,
187        cx: &mut Context<Self>,
188    ) -> Task<anyhow::Result<()>> {
189        cx.spawn(async move |this, cx| {
190            let id = server.id.clone();
191            server.start(&cx).await?;
192            this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
193            Ok(())
194        })
195    }
196
197    pub fn stop_server(
198        &self,
199        server: Arc<ContextServer>,
200        cx: &mut Context<Self>,
201    ) -> anyhow::Result<()> {
202        server.stop()?;
203        cx.emit(Event::ServerStopped {
204            server_id: server.id(),
205        });
206        Ok(())
207    }
208
209    pub fn restart_server(
210        &mut self,
211        id: &Arc<str>,
212        cx: &mut Context<Self>,
213    ) -> Task<anyhow::Result<()>> {
214        let id = id.clone();
215        cx.spawn(async move |this, cx| {
216            if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
217                server.stop()?;
218                let config = server.config();
219                let new_server = Arc::new(ContextServer::new(id.clone(), config));
220                new_server.clone().start(&cx).await?;
221                this.update(cx, |this, cx| {
222                    this.servers.insert(id.clone(), new_server);
223                    cx.emit(Event::ServerStopped {
224                        server_id: id.clone(),
225                    });
226                    cx.emit(Event::ServerStarted {
227                        server_id: id.clone(),
228                    });
229                })?;
230            }
231            Ok(())
232        })
233    }
234
235    pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
236        self.servers.values().cloned().collect()
237    }
238
239    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
240        self.servers
241            .values()
242            .filter(|server| server.client().is_some())
243            .cloned()
244            .collect()
245    }
246
247    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
248        let mut desired_servers = HashMap::default();
249
250        let (registry, project) = this.update(cx, |this, cx| {
251            let location = this
252                .project
253                .read(cx)
254                .visible_worktrees(cx)
255                .next()
256                .map(|worktree| settings::SettingsLocation {
257                    worktree_id: worktree.read(cx).id(),
258                    path: Path::new(""),
259                });
260            let settings = ContextServerSettings::get(location, cx);
261            desired_servers = settings.context_servers.clone();
262
263            (this.registry.clone(), this.project.clone())
264        })?;
265
266        for (id, factory) in
267            registry.read_with(cx, |registry, _| registry.context_server_factories())?
268        {
269            let config = desired_servers.entry(id).or_default();
270            if config.command.is_none() {
271                if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
272                    config.command = Some(extension_command);
273                }
274            }
275        }
276
277        let mut servers_to_start = HashMap::default();
278        let mut servers_to_stop = HashMap::default();
279
280        this.update(cx, |this, _cx| {
281            this.servers.retain(|id, server| {
282                if desired_servers.contains_key(id) {
283                    true
284                } else {
285                    servers_to_stop.insert(id.clone(), server.clone());
286                    false
287                }
288            });
289
290            for (id, config) in desired_servers {
291                let existing_config = this.servers.get(&id).map(|server| server.config());
292                if existing_config.as_deref() != Some(&config) {
293                    let config = Arc::new(config);
294                    let server = Arc::new(ContextServer::new(id.clone(), config));
295                    servers_to_start.insert(id.clone(), server.clone());
296                    let old_server = this.servers.insert(id.clone(), server);
297                    if let Some(old_server) = old_server {
298                        servers_to_stop.insert(id, old_server);
299                    }
300                }
301            }
302        })?;
303
304        for (id, server) in servers_to_stop {
305            server.stop().log_err();
306            this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
307        }
308
309        for (id, server) in servers_to_start {
310            if server.start(&cx).await.log_err().is_some() {
311                this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
312            }
313        }
314
315        Ok(())
316    }
317}