manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use std::path::Path;
 18use std::pin::Pin;
 19use std::sync::Arc;
 20
 21use anyhow::Result;
 22use async_trait::async_trait;
 23use collections::{HashMap, HashSet};
 24use futures::{Future, FutureExt};
 25use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
 26use log;
 27use parking_lot::RwLock;
 28use schemars::JsonSchema;
 29use serde::{Deserialize, Serialize};
 30use settings::{Settings, SettingsSources};
 31
 32use crate::{
 33    client::{self, Client},
 34    types,
 35};
 36
 37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 38pub struct ContextServerSettings {
 39    pub servers: Vec<ServerConfig>,
 40}
 41
 42#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 43pub struct ServerConfig {
 44    pub id: String,
 45    pub executable: String,
 46    pub args: Vec<String>,
 47    pub env: Option<HashMap<String, String>>,
 48}
 49
 50impl Settings for ContextServerSettings {
 51    const KEY: Option<&'static str> = Some("experimental.context_servers");
 52
 53    type FileContent = Self;
 54
 55    fn load(
 56        sources: SettingsSources<Self::FileContent>,
 57        _: &mut gpui::AppContext,
 58    ) -> anyhow::Result<Self> {
 59        sources.json_merge()
 60    }
 61}
 62
 63#[async_trait(?Send)]
 64pub trait ContextServer: Send + Sync + 'static {
 65    fn id(&self) -> Arc<str>;
 66    fn config(&self) -> Arc<ServerConfig>;
 67    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>>;
 68    fn start<'a>(
 69        self: Arc<Self>,
 70        cx: &'a AsyncAppContext,
 71    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>>;
 72    fn stop(&self) -> Result<()>;
 73}
 74
/// A [`ContextServer`] described by a [`ServerConfig`]. On `start`, the
/// config's executable, args and env are handed to `client::Client::new`.
pub struct NativeContextServer {
    /// Server id, copied from `config.id` when the server is constructed.
    pub id: Arc<str>,
    /// Configuration the server was created from.
    pub config: Arc<ServerConfig>,
    /// Initialized protocol client: `None` until `start` succeeds, and reset
    /// to `None` by `stop`.
    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
 80
 81impl NativeContextServer {
 82    fn new(config: Arc<ServerConfig>) -> Self {
 83        Self {
 84            id: config.id.clone().into(),
 85            config,
 86            client: RwLock::new(None),
 87        }
 88    }
 89}
 90
 91#[async_trait(?Send)]
 92impl ContextServer for NativeContextServer {
 93    fn id(&self) -> Arc<str> {
 94        self.id.clone()
 95    }
 96
 97    fn config(&self) -> Arc<ServerConfig> {
 98        self.config.clone()
 99    }
100
101    fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
102        self.client.read().clone()
103    }
104
105    fn start<'a>(
106        self: Arc<Self>,
107        cx: &'a AsyncAppContext,
108    ) -> Pin<Box<dyn 'a + Future<Output = Result<()>>>> {
109        async move {
110            log::info!("starting context server {}", self.config.id,);
111            let client = Client::new(
112                client::ContextServerId(self.config.id.clone()),
113                client::ModelContextServerBinary {
114                    executable: Path::new(&self.config.executable).to_path_buf(),
115                    args: self.config.args.clone(),
116                    env: self.config.env.clone(),
117                },
118                cx.clone(),
119            )?;
120
121            let protocol = crate::protocol::ModelContextProtocol::new(client);
122            let client_info = types::Implementation {
123                name: "Zed".to_string(),
124                version: env!("CARGO_PKG_VERSION").to_string(),
125            };
126            let initialized_protocol = protocol.initialize(client_info).await?;
127
128            log::debug!(
129                "context server {} initialized: {:?}",
130                self.config.id,
131                initialized_protocol.initialize,
132            );
133
134            *self.client.write() = Some(Arc::new(initialized_protocol));
135            Ok(())
136        }
137        .boxed_local()
138    }
139
140    fn stop(&self) -> Result<()> {
141        let mut client = self.client.write();
142        if let Some(protocol) = client.take() {
143            drop(protocol);
144        }
145        Ok(())
146    }
147}
148
/// A Context server manager manages the starting and stopping
/// of all servers. To obtain a server to interact with, a crate
/// must go through the `GlobalContextServerManager` which holds
/// a model to the ContextServerManager.
pub struct ContextServerManager {
    /// Servers whose start succeeded, keyed by server id.
    servers: HashMap<String, Arc<dyn ContextServer>>,
    /// Ids for which `add_server` spawned a start that has not (yet) been
    /// registered in `servers`.
    pending_servers: HashSet<String>,
}
157
/// Events emitted by [`ContextServerManager`] as the set of servers changes.
pub enum Event {
    /// Emitted once a server has started and been registered under `server_id`.
    ServerStarted { server_id: String },
    /// Emitted by `remove_server` and `restart_server` for `server_id`.
    ServerStopped { server_id: String },
}

// Allows `ContextServerManager` to emit `Event`s through `cx.emit(...)`.
impl EventEmitter<Event> for ContextServerManager {}
164
165impl Default for ContextServerManager {
166    fn default() -> Self {
167        Self::new()
168    }
169}
170
171impl ContextServerManager {
172    pub fn new() -> Self {
173        Self {
174            servers: HashMap::default(),
175            pending_servers: HashSet::default(),
176        }
177    }
178
179    pub fn add_server(
180        &mut self,
181        config: Arc<ServerConfig>,
182        cx: &ModelContext<Self>,
183    ) -> Task<anyhow::Result<()>> {
184        let server_id = config.id.clone();
185
186        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
187            return Task::ready(Ok(()));
188        }
189
190        let task = {
191            let server_id = server_id.clone();
192            cx.spawn(|this, mut cx| async move {
193                let server = Arc::new(NativeContextServer::new(config));
194                server.clone().start(&cx).await?;
195                this.update(&mut cx, |this, cx| {
196                    this.servers.insert(server_id.clone(), server);
197                    this.pending_servers.remove(&server_id);
198                    cx.emit(Event::ServerStarted {
199                        server_id: server_id.clone(),
200                    });
201                })?;
202                Ok(())
203            })
204        };
205
206        self.pending_servers.insert(server_id);
207        task
208    }
209
210    pub fn get_server(&self, id: &str) -> Option<Arc<dyn ContextServer>> {
211        self.servers.get(id).cloned()
212    }
213
214    pub fn remove_server(&mut self, id: &str, cx: &ModelContext<Self>) -> Task<anyhow::Result<()>> {
215        let id = id.to_string();
216        cx.spawn(|this, mut cx| async move {
217            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
218                server.stop()?;
219            }
220            this.update(&mut cx, |this, cx| {
221                this.pending_servers.remove(&id);
222                cx.emit(Event::ServerStopped {
223                    server_id: id.clone(),
224                })
225            })?;
226            Ok(())
227        })
228    }
229
230    pub fn restart_server(
231        &mut self,
232        id: &Arc<str>,
233        cx: &mut ModelContext<Self>,
234    ) -> Task<anyhow::Result<()>> {
235        let id = id.to_string();
236        cx.spawn(|this, mut cx| async move {
237            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
238                server.stop()?;
239                let config = server.config();
240                let new_server = Arc::new(NativeContextServer::new(config));
241                new_server.clone().start(&cx).await?;
242                this.update(&mut cx, |this, cx| {
243                    this.servers.insert(id.clone(), new_server);
244                    cx.emit(Event::ServerStopped {
245                        server_id: id.clone(),
246                    });
247                    cx.emit(Event::ServerStarted {
248                        server_id: id.clone(),
249                    });
250                })?;
251            }
252            Ok(())
253        })
254    }
255
256    pub fn servers(&self) -> Vec<Arc<dyn ContextServer>> {
257        self.servers.values().cloned().collect()
258    }
259
260    pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
261        let current_servers = self
262            .servers()
263            .into_iter()
264            .map(|server| (server.id(), server.config()))
265            .collect::<HashMap<_, _>>();
266
267        let new_servers = settings
268            .servers
269            .iter()
270            .map(|config| (config.id.clone(), config.clone()))
271            .collect::<HashMap<_, _>>();
272
273        let servers_to_add = new_servers
274            .values()
275            .filter(|config| !current_servers.contains_key(config.id.as_str()))
276            .cloned()
277            .collect::<Vec<_>>();
278
279        let servers_to_remove = current_servers
280            .keys()
281            .filter(|id| !new_servers.contains_key(id.as_ref()))
282            .cloned()
283            .collect::<Vec<_>>();
284
285        log::trace!("servers_to_add={:?}", servers_to_add);
286        for config in servers_to_add {
287            self.add_server(Arc::new(config), cx).detach_and_log_err(cx);
288        }
289
290        for id in servers_to_remove {
291            self.remove_server(&id, cx).detach_and_log_err(cx);
292        }
293    }
294}