manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use std::path::Path;
 18use std::pin::Pin;
 19use std::sync::Arc;
 20
 21use anyhow::Result;
 22use async_trait::async_trait;
 23use collections::{HashMap, HashSet};
 24use futures::{Future, FutureExt};
 25use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
 26use log;
 27use parking_lot::RwLock;
 28use schemars::JsonSchema;
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings, SettingsSources};
 31
 32use crate::{
 33    client::{self, Client},
 34    types,
 35};
 36
 37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 38pub struct ContextServerSettings {
 39    pub servers: Vec<ServerConfig>,
 40}
 41
 42#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 43pub struct ServerConfig {
 44    pub id: String,
 45    pub executable: String,
 46    pub args: Vec<String>,
 47    pub env: Option<HashMap<String, String>>,
 48}
 49
 50impl Settings for ContextServerSettings {
 51    const KEY: Option<&'static str> = Some("experimental.context_servers");
 52
 53    type FileContent = Self;
 54
 55    fn load(
 56        sources: SettingsSources<Self::FileContent>,
 57        _: &mut gpui::AppContext,
 58    ) -> anyhow::Result<Self> {
 59        sources.json_merge()
 60    }
 61}
 62
 63#[async_trait(?Send)]
 64pub trait ContextServer: Send + Sync + 'static {
 65    fn id(&self) -> Arc<str>;
 66    fn config(&self) -> Arc<ServerConfig>;
 67    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
 68    fn start<'a>(
 69        self: Arc<Self>,
 70        cx: &'a AsyncAppContext,
 71    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
 72    fn stop(&self) -> Result<()>;
 73}
 74
 75pub struct NativeContextServer {
 76    pub id: Arc<str>,
 77    pub config: Arc<ServerConfig>,
 78    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 79}
 80
 81impl NativeContextServer {
 82    pub fn new(config: Arc<ServerConfig>) -> Self {
 83        Self {
 84            id: config.id.clone().into(),
 85            config,
 86            client: RwLock::new(None),
 87        }
 88    }
 89}
 90
 91#[async_trait(?Send)]
 92impl ContextServer for NativeContextServer {
 93    fn id(&self) -> Arc<str> {
 94        self.id.clone()
 95    }
 96
 97    fn config(&self) -> Arc<ServerConfig> {
 98        self.config.clone()
 99    }
100
101    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
102        self.client.read().clone()
103    }
104
105    fn start<'a>(
106        self: Arc<Self>,
107        cx: &'a AsyncAppContext,
108    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
109        async move {
110            log::info!("starting context server {}", self.config.id,);
111            let client = Client::new(
112                client::ContextServerId(self.config.id.clone()),
113                client::ModelContextServerBinary {
114                    executable: Path::new(&self.config.executable).to_path_buf(),
115                    args: self.config.args.clone(),
116                    env: self.config.env.clone(),
117                },
118                cx.clone(),
119            )?;
120
121            let protocol = crate::protocol::ModelContextProtocol::new(client);
122            let client_info = types::Implementation {
123                name: "Zed".to_string(),
124                version: env!("CARGO_PKG_VERSION").to_string(),
125            };
126            let initialized_protocol = protocol.initialize(client_info).await?;
127
128            log::debug!(
129                "context server {} initialized: {:?}",
130                self.config.id,
131                initialized_protocol.initialize,
132            );
133
134            *self.client.write() = Some(Arc::new(initialized_protocol));
135            Ok(())
136        }
137        .boxed_local()
138    }
139
140    fn stop(&self) -> Result<()> {
141        let mut client = self.client.write();
142        if let Some(protocol) = client.take() {
143            drop(protocol);
144        }
145        Ok(())
146    }
147}
148
149/// A Context server manager manages the starting and stopping
150/// of all servers. To obtain a server to interact with, a crate
151/// must go through the `GlobalContextServerManager` which holds
152/// a model to the ContextServerManager.
153pub struct ContextServerManager {
154    servers: HashMap<Arc<str>, Arc<dyn ContextServer>>,
155    pending_servers: HashSet<Arc<str>>,
156}
157
/// Notifications emitted by the context server manager as servers come and go.
pub enum Event {
    /// A server finished starting and is now registered with the manager.
    ServerStarted {
        server_id: Arc<str>,
    },
    /// A server was stopped or removed from the manager.
    ServerStopped {
        server_id: Arc<str>,
    },
}
162
163impl EventEmitter<Event> for ContextServerManager {}
164
165impl Default for ContextServerManager {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171impl ContextServerManager {
172    pub fn new() -> Self {
173        Self {
174            servers: HashMap::default(),
175            pending_servers: HashSet::default(),
176        }
177    }
178
179    pub fn add_server(
180        &mut self,
181        server: Arc<dyn ContextServer>,
182        cx: &ModelContext<Self>,
183    ) -> Task<anyhow::Result<()>> {
184        let server_id = server.id();
185
186        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
187            return Task::ready(Ok(()));
188        }
189
190        let task = {
191            let server_id = server_id.clone();
192            cx.spawn(|this, mut cx| async move {
193                server.clone().start(&cx).await?;
194                this.update(&mut cx, |this, cx| {
195                    this.servers.insert(server_id.clone(), server);
196                    this.pending_servers.remove(&server_id);
197                    cx.emit(Event::ServerStarted {
198                        server_id: server_id.clone(),
199                    });
200                })?;
201                Ok(())
202            })
203        };
204
205        self.pending_servers.insert(server_id);
206        task
207    }
208
209    pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
210        self.servers.get(id).cloned()
211    }
212
213    pub fn remove_server(
214        &mut self,
215        id: &Arc<str>,
216        cx: &ModelContext<Self>,
217    ) -> Task<anyhow::Result<()>> {
218        let id = id.clone();
219        cx.spawn(|this, mut cx| async move {
220            if let Some(server) =
221                this.update(&mut cx, |this, _cx| this.servers.remove(id.as_ref()))?
222            {
223                server.stop()?;
224            }
225            this.update(&mut cx, |this, cx| {
226                this.pending_servers.remove(id.as_ref());
227                cx.emit(Event::ServerStopped {
228                    server_id: id.clone(),
229                })
230            })?;
231            Ok(())
232        })
233    }
234
235    pub fn restart_server(
236        &mut self,
237        id: &Arc<str>,
238        cx: &mut ModelContext<Self>,
239    ) -> Task<anyhow::Result<()>> {
240        let id = id.clone();
241        cx.spawn(|this, mut cx| async move {
242            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
243                server.stop()?;
244                let config = server.config();
245                let new_server = Arc::new(NativeContextServer::new(config));
246                new_server.clone().start(&cx).await?;
247                this.update(&mut cx, |this, cx| {
248                    this.servers.insert(id.clone(), new_server);
249                    cx.emit(Event::ServerStopped {
250                        server_id: id.clone(),
251                    });
252                    cx.emit(Event::ServerStarted {
253                        server_id: id.clone(),
254                    });
255                })?;
256            }
257            Ok(())
258        })
259    }
260
261    pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
262        self.servers.values().cloned().collect()
263    }
264
265    pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
266        let current_servers = self
267            .servers()
268            .into_iter()
269            .map(|server| (server.id(), server.config()))
270            .collect::<HashMap<_, _>>();
271
272        let new_servers = settings
273            .servers
274            .iter()
275            .map(|config| (config.id.clone(), config.clone()))
276            .collect::<HashMap<_, _>>();
277
278        let servers_to_add = new_servers
279            .values()
280            .filter(|config| !current_servers.contains_key(config.id.as_str()))
281            .cloned()
282            .collect::<Vec<_>>();
283
284        let servers_to_remove = current_servers
285            .keys()
286            .filter(|id| !new_servers.contains_key(id.as_ref()))
287            .cloned()
288            .collect::<Vec<_>>();
289
290        log::trace!("servers_to_add={:?}", servers_to_add);
291        for config in servers_to_add {
292            let server = Arc::new(NativeContextServer::new(Arc::new(config)));
293            self.add_server(server, cx).detach_and_log_err(cx);
294        }
295
296        for id in servers_to_remove {
297            self.remove_server(&id, cx).detach_and_log_err(cx);
298        }
299    }
300}