manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use std::path::Path;
 18use std::pin::Pin;
 19use std::sync::Arc;
 20
 21use anyhow::{bail, Result};
 22use async_trait::async_trait;
 23use collections::{HashMap, HashSet};
 24use futures::{Future, FutureExt};
 25use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
 26use log;
 27use parking_lot::RwLock;
 28use schemars::JsonSchema;
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings, SettingsSources};
 31
 32use crate::{
 33    client::{self, Client},
 34    types,
 35};
 36
 37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 38pub struct ContextServerSettings {
 39    #[serde(default)]
 40    pub context_servers: HashMap<Arc<str>, ServerConfig>,
 41}
 42
 43#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
 44pub struct ServerConfig {
 45    pub command: Option<ServerCommand>,
 46    pub settings: Option<serde_json::Value>,
 47}
 48
 49#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 50pub struct ServerCommand {
 51    pub path: String,
 52    pub args: Vec<String>,
 53    pub env: Option<HashMap<String, String>>,
 54}
 55
 56impl Settings for ContextServerSettings {
 57    const KEY: Option<&'static str> = None;
 58
 59    type FileContent = Self;
 60
 61    fn load(
 62        sources: SettingsSources<Self::FileContent>,
 63        _: &mut gpui::AppContext,
 64    ) -> anyhow::Result<Self> {
 65        sources.json_merge()
 66    }
 67}
 68
 69#[async_trait(?Send)]
 70pub trait ContextServer: Send + Sync + 'static {
 71    fn id(&self) -> Arc<str>;
 72    fn config(&self) -> Arc<ServerConfig>;
 73    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
 74    fn start<'a>(
 75        self: Arc<Self>,
 76        cx: &'a AsyncAppContext,
 77    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
 78    fn stop(&self) -> Result<()>;
 79}
 80
 81pub struct NativeContextServer {
 82    pub id: Arc<str>,
 83    pub config: Arc<ServerConfig>,
 84    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 85}
 86
 87impl NativeContextServer {
 88    pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
 89        Self {
 90            id,
 91            config,
 92            client: RwLock::new(None),
 93        }
 94    }
 95}
 96
 97#[async_trait(?Send)]
 98impl ContextServer for NativeContextServer {
 99    fn id(&self) -> Arc<str> {
100        self.id.clone()
101    }
102
103    fn config(&self) -> Arc<ServerConfig> {
104        self.config.clone()
105    }
106
107    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
108        self.client.read().clone()
109    }
110
111    fn start<'a>(
112        self: Arc<Self>,
113        cx: &'a AsyncAppContext,
114    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
115        async move {
116            log::info!("starting context server {}", self.id);
117            let Some(command) = &self.config.command else {
118                bail!("no command specified for server {}", self.id);
119            };
120            let client = Client::new(
121                client::ContextServerId(self.id.clone()),
122                client::ModelContextServerBinary {
123                    executable: Path::new(&command.path).to_path_buf(),
124                    args: command.args.clone(),
125                    env: command.env.clone(),
126                },
127                cx.clone(),
128            )?;
129
130            let protocol = crate::protocol::ModelContextProtocol::new(client);
131            let client_info = types::Implementation {
132                name: "Zed".to_string(),
133                version: env!("CARGO_PKG_VERSION").to_string(),
134            };
135            let initialized_protocol = protocol.initialize(client_info).await?;
136
137            log::debug!(
138                "context server {} initialized: {:?}",
139                self.id,
140                initialized_protocol.initialize,
141            );
142
143            *self.client.write() = Some(Arc::new(initialized_protocol));
144            Ok(())
145        }
146        .boxed_local()
147    }
148
149    fn stop(&self) -> Result<()> {
150        let mut client = self.client.write();
151        if let Some(protocol) = client.take() {
152            drop(protocol);
153        }
154        Ok(())
155    }
156}
157
158/// A Context server manager manages the starting and stopping
159/// of all servers. To obtain a server to interact with, a crate
160/// must go through the `GlobalContextServerManager` which holds
161/// a model to the ContextServerManager.
162pub struct ContextServerManager {
163    servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
164    pending_servers: HashSet<Arc<str>>,
165}
166
167pub enum Event {
168    ServerStarted { server_id: Arc<str> },
169    ServerStopped { server_id: Arc<str> },
170}
171
172impl EventEmitter<Event> for ContextServerManager {}
173
174impl Default for ContextServerManager {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
180impl ContextServerManager {
181    pub fn new() -> Self {
182        Self {
183            servers: HashMap::default(),
184            pending_servers: HashSet::default(),
185        }
186    }
187
188    pub fn add_server(
189        &mut self,
190        server: Arc<dyn ContextServer>,
191        cx: &ModelContext<Self>,
192    ) -> Task<anyhow::Result<()>> {
193        let server_id = server.id();
194
195        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
196            return Task::ready(Ok(()));
197        }
198
199        let task = {
200            let server_id = server_id.clone();
201            cx.spawn(|this, mut cx| async move {
202                server.clone().start(&cx).await?;
203                this.update(&mut cx, |this, cx| {
204                    this.servers.insert(server_id.clone(), server);
205                    this.pending_servers.remove(&server_id);
206                    cx.emit(Event::ServerStarted {
207                        server_id: server_id.clone(),
208                    });
209                })?;
210                Ok(())
211            })
212        };
213
214        self.pending_servers.insert(server_id);
215        task
216    }
217
218    pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
219        self.servers.get(id).cloned()
220    }
221
222    pub fn remove_server(
223        &mut self,
224        id: &Arc<str>,
225        cx: &ModelContext<Self>,
226    ) -> Task<anyhow::Result<()>> {
227        let id = id.clone();
228        cx.spawn(|this, mut cx| async move {
229            if let Some(server) =
230                this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
231            {
232                server.stop()?;
233            }
234            this.update(&mut cx, |this, cx| {
235                this.pending_servers.remove(id.as_ref());
236                cx.emit(Event::ServerStopped {
237                    server_id: id.clone(),
238                })
239            })?;
240            Ok(())
241        })
242    }
243
244    pub fn restart_server(
245        &mut self,
246        id: &Arc<str>,
247        cx: &mut ModelContext<Self>,
248    ) -> Task<anyhow::Result<()>> {
249        let id = id.clone();
250        cx.spawn(|this, mut cx| async move {
251            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
252                server.stop()?;
253                let config = server.config();
254                let new_server = Arc::new(NativeContextServer::new(id.clone(), config));
255                new_server.clone().start(&cx).await?;
256                this.update(&mut cx, |this, cx| {
257                    this.servers.insert(id.clone(), new_server);
258                    cx.emit(Event::ServerStopped {
259                        server_id: id.clone(),
260                    });
261                    cx.emit(Event::ServerStarted {
262                        server_id: id.clone(),
263                    });
264                })?;
265            }
266            Ok(())
267        })
268    }
269
270    pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
271        self.servers.values().cloned().collect()
272    }
273
274    pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
275        let current_servers = self
276            .servers()
277            .into_iter()
278            .map(|server| (server.id(), server.config()))
279            .collect::<HashMap<_, _>>();
280
281        let new_servers = settings
282            .context_servers
283            .iter()
284            .map(|(id, config)| (id.clone(), config.clone()))
285            .collect::<HashMap<_, _>>();
286
287        let servers_to_add = new_servers
288            .iter()
289            .filter(|(id, _)| !current_servers.contains_key(id.as_ref()))
290            .map(|(id, config)| (id.clone(), config.clone()))
291            .collect::<Vec<_>>();
292
293        let servers_to_remove = current_servers
294            .keys()
295            .filter(|id| !new_servers.contains_key(id.as_ref()))
296            .cloned()
297            .collect::<Vec<_>>();
298
299        log::trace!("servers_to_add={:?}", servers_to_add);
300        for (id, config) in servers_to_add {
301            if config.command.is_some() {
302                let server = Arc::new(NativeContextServer::new(id, Arc::new(config)));
303                self.add_server(server, cx).detach_and_log_err(cx);
304            }
305        }
306
307        for id in servers_to_remove {
308            self.remove_server(&id, cx).detach_and_log_err(cx);
309        }
310    }
311}