manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20use anyhow::{bail, Result};
 21use collections::HashMap;
 22use command_palette_hooks::CommandPaletteFilter;
 23use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
 24use log;
 25use parking_lot::RwLock;
 26use project::Project;
 27use settings::{Settings, SettingsStore};
 28use util::ResultExt as _;
 29
 30use crate::{ContextServerSettings, ServerConfig};
 31
 32use crate::{
 33    client::{self, Client},
 34    types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
 35};
 36
 37pub struct ContextServer {
 38    pub id: Arc<str>,
 39    pub config: Arc<ServerConfig>,
 40    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 41}
 42
 43impl ContextServer {
 44    pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
 45        Self {
 46            id,
 47            config,
 48            client: RwLock::new(None),
 49        }
 50    }
 51
 52    pub fn id(&self) -> Arc<str> {
 53        self.id.clone()
 54    }
 55
 56    pub fn config(&self) -> Arc<ServerConfig> {
 57        self.config.clone()
 58    }
 59
 60    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 61        self.client.read().clone()
 62    }
 63
 64    pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
 65        log::info!("starting context server {}", self.id);
 66        let Some(command) = &self.config.command else {
 67            bail!("no command specified for server {}", self.id);
 68        };
 69        let client = Client::new(
 70            client::ContextServerId(self.id.clone()),
 71            client::ModelContextServerBinary {
 72                executable: Path::new(&command.path).to_path_buf(),
 73                args: command.args.clone(),
 74                env: command.env.clone(),
 75            },
 76            cx.clone(),
 77        )?;
 78
 79        let protocol = crate::protocol::ModelContextProtocol::new(client);
 80        let client_info = types::Implementation {
 81            name: "Zed".to_string(),
 82            version: env!("CARGO_PKG_VERSION").to_string(),
 83        };
 84        let initialized_protocol = protocol.initialize(client_info).await?;
 85
 86        log::debug!(
 87            "context server {} initialized: {:?}",
 88            self.id,
 89            initialized_protocol.initialize,
 90        );
 91
 92        *self.client.write() = Some(Arc::new(initialized_protocol));
 93        Ok(())
 94    }
 95
 96    pub fn stop(&self) -> Result<()> {
 97        let mut client = self.client.write();
 98        if let Some(protocol) = client.take() {
 99            drop(protocol);
100        }
101        Ok(())
102    }
103}
104
105pub struct ContextServerManager {
106    servers: HashMap<Arc<str>, Arc<ContextServer>>,
107    project: Entity<Project>,
108    registry: Entity<ContextServerFactoryRegistry>,
109    update_servers_task: Option<Task<Result<()>>>,
110    needs_server_update: bool,
111    _subscriptions: Vec<Subscription>,
112}
113
/// Lifecycle notifications emitted by [`ContextServerManager`].
pub enum Event {
    /// The server with this id finished starting successfully.
    ServerStarted { server_id: Arc<str> },
    /// The server with this id was stopped.
    ServerStopped { server_id: Arc<str> },
}
118
119impl EventEmitter<Event> for ContextServerManager {}
120
121impl ContextServerManager {
122    pub fn new(
123        registry: Entity<ContextServerFactoryRegistry>,
124        project: Entity<Project>,
125        cx: &mut Context<Self>,
126    ) -> Self {
127        let mut this = Self {
128            _subscriptions: vec![
129                cx.observe(&registry, |this, _registry, cx| {
130                    this.available_context_servers_changed(cx);
131                }),
132                cx.observe_global::<SettingsStore>(|this, cx| {
133                    this.available_context_servers_changed(cx);
134                }),
135            ],
136            project,
137            registry,
138            needs_server_update: false,
139            servers: HashMap::default(),
140            update_servers_task: None,
141        };
142        this.available_context_servers_changed(cx);
143        this
144    }
145
146    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
147        if self.update_servers_task.is_some() {
148            self.needs_server_update = true;
149        } else {
150            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
151                this.update(cx, |this, _| {
152                    this.needs_server_update = false;
153                })?;
154
155                Self::maintain_servers(this.clone(), cx).await?;
156
157                this.update(cx, |this, cx| {
158                    let has_any_context_servers = !this.running_servers().is_empty();
159                    if has_any_context_servers {
160                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
161                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
162                        });
163                    }
164
165                    this.update_servers_task.take();
166                    if this.needs_server_update {
167                        this.available_context_servers_changed(cx);
168                    }
169                })?;
170
171                Ok(())
172            }));
173        }
174    }
175
176    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
177        self.servers
178            .get(id)
179            .filter(|server| server.client().is_some())
180            .cloned()
181    }
182
183    pub fn start_server(
184        &self,
185        server: Arc<ContextServer>,
186        cx: &mut Context<Self>,
187    ) -> Task<anyhow::Result<()>> {
188        cx.spawn(async move |this, cx| {
189            let id = server.id.clone();
190            server.start(&cx).await?;
191            this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
192            Ok(())
193        })
194    }
195
196    pub fn stop_server(
197        &self,
198        server: Arc<ContextServer>,
199        cx: &mut Context<Self>,
200    ) -> anyhow::Result<()> {
201        server.stop()?;
202        cx.emit(Event::ServerStopped {
203            server_id: server.id(),
204        });
205        Ok(())
206    }
207
208    pub fn restart_server(
209        &mut self,
210        id: &Arc<str>,
211        cx: &mut Context<Self>,
212    ) -> Task<anyhow::Result<()>> {
213        let id = id.clone();
214        cx.spawn(async move |this, cx| {
215            if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
216                server.stop()?;
217                let config = server.config();
218                let new_server = Arc::new(ContextServer::new(id.clone(), config));
219                new_server.clone().start(&cx).await?;
220                this.update(cx, |this, cx| {
221                    this.servers.insert(id.clone(), new_server);
222                    cx.emit(Event::ServerStopped {
223                        server_id: id.clone(),
224                    });
225                    cx.emit(Event::ServerStarted {
226                        server_id: id.clone(),
227                    });
228                })?;
229            }
230            Ok(())
231        })
232    }
233
234    pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
235        self.servers.values().cloned().collect()
236    }
237
238    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
239        self.servers
240            .values()
241            .filter(|server| server.client().is_some())
242            .cloned()
243            .collect()
244    }
245
246    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
247        let mut desired_servers = HashMap::default();
248
249        let (registry, project) = this.update(cx, |this, cx| {
250            let location = this
251                .project
252                .read(cx)
253                .visible_worktrees(cx)
254                .next()
255                .map(|worktree| settings::SettingsLocation {
256                    worktree_id: worktree.read(cx).id(),
257                    path: Path::new(""),
258                });
259            let settings = ContextServerSettings::get(location, cx);
260            desired_servers = settings.context_servers.clone();
261
262            (this.registry.clone(), this.project.clone())
263        })?;
264
265        for (id, factory) in
266            registry.read_with(cx, |registry, _| registry.context_server_factories())?
267        {
268            let config = desired_servers.entry(id).or_default();
269            if config.command.is_none() {
270                if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
271                    config.command = Some(extension_command);
272                }
273            }
274        }
275
276        let mut servers_to_start = HashMap::default();
277        let mut servers_to_stop = HashMap::default();
278
279        this.update(cx, |this, _cx| {
280            this.servers.retain(|id, server| {
281                if desired_servers.contains_key(id) {
282                    true
283                } else {
284                    servers_to_stop.insert(id.clone(), server.clone());
285                    false
286                }
287            });
288
289            for (id, config) in desired_servers {
290                let existing_config = this.servers.get(&id).map(|server| server.config());
291                if existing_config.as_deref() != Some(&config) {
292                    let config = Arc::new(config);
293                    let server = Arc::new(ContextServer::new(id.clone(), config));
294                    servers_to_start.insert(id.clone(), server.clone());
295                    let old_server = this.servers.insert(id.clone(), server);
296                    if let Some(old_server) = old_server {
297                        servers_to_stop.insert(id, old_server);
298                    }
299                }
300            }
301        })?;
302
303        for (id, server) in servers_to_stop {
304            server.stop().log_err();
305            this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
306        }
307
308        for (id, server) in servers_to_start {
309            if server.start(&cx).await.log_err().is_some() {
310                this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
311            }
312        }
313
314        Ok(())
315    }
316}