manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20use collections::{HashMap, HashSet};
 21use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
 22use log;
 23use parking_lot::RwLock;
 24use schemars::JsonSchema;
 25use serde::{Deserialize, Serialize};
 26use settings::{Settings, SettingsSources};
 27
 28use crate::{
 29    client::{self, Client},
 30    types,
 31};
 32
 33#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 34pub struct ContextServerSettings {
 35    pub servers: Vec<ServerConfig>,
 36}
 37
 38#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 39pub struct ServerConfig {
 40    pub id: String,
 41    pub executable: String,
 42    pub args: Vec<String>,
 43    pub env: Option<HashMap<String, String>>,
 44}
 45
 46impl Settings for ContextServerSettings {
 47    const KEY: Option<&'static str> = Some("experimental.context_servers");
 48
 49    type FileContent = Self;
 50
 51    fn load(
 52        sources: SettingsSources<Self::FileContent>,
 53        _: &mut gpui::AppContext,
 54    ) -> anyhow::Result<Self> {
 55        sources.json_merge()
 56    }
 57}
 58
 59pub struct ContextServer {
 60    pub id: String,
 61    pub config: ServerConfig,
 62    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 63}
 64
 65impl ContextServer {
 66    fn new(config: ServerConfig) -> Self {
 67        Self {
 68            id: config.id.clone(),
 69            config,
 70            client: RwLock::new(None),
 71        }
 72    }
 73
 74    async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
 75        log::info!("starting context server {}", self.config.id,);
 76        let client = Client::new(
 77            client::ContextServerId(self.config.id.clone()),
 78            client::ModelContextServerBinary {
 79                executable: Path::new(&self.config.executable).to_path_buf(),
 80                args: self.config.args.clone(),
 81                env: self.config.env.clone(),
 82            },
 83            cx.clone(),
 84        )?;
 85
 86        let protocol = crate::protocol::ModelContextProtocol::new(client);
 87        let client_info = types::Implementation {
 88            name: "Zed".to_string(),
 89            version: env!("CARGO_PKG_VERSION").to_string(),
 90        };
 91        let initialized_protocol = protocol.initialize(client_info).await?;
 92
 93        log::debug!(
 94            "context server {} initialized: {:?}",
 95            self.config.id,
 96            initialized_protocol.initialize,
 97        );
 98
 99        *self.client.write() = Some(Arc::new(initialized_protocol));
100        Ok(())
101    }
102
103    async fn stop(&self) -> anyhow::Result<()> {
104        let mut client = self.client.write();
105        if let Some(protocol) = client.take() {
106            drop(protocol);
107        }
108        Ok(())
109    }
110}
111
112/// A Context server manager manages the starting and stopping
113/// of all servers. To obtain a server to interact with, a crate
114/// must go through the `GlobalContextServerManager` which holds
115/// a model to the ContextServerManager.
116pub struct ContextServerManager {
117    servers: HashMap<String, Arc<ContextServer>>,
118    pending_servers: HashSet<String>,
119}
120
/// Notifications emitted by `ContextServerManager` as servers start and stop.
///
/// Derives `Debug`, `Clone` and equality so that subscribers can log, forward, or
/// compare events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A server finished starting and was registered under `server_id`.
    ServerStarted { server_id: String },
    /// A server was removed or restarted.
    ServerStopped { server_id: String },
}
125
/// Declares the manager an emitter of [`Event`]. This is what permits the
/// `cx.emit(Event::...)` calls in its methods.
impl EventEmitter<Event> for ContextServerManager {}
127
128impl Default for ContextServerManager {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl ContextServerManager {
135    pub fn new() -> Self {
136        Self {
137            servers: HashMap::default(),
138            pending_servers: HashSet::default(),
139        }
140    }
141
142    pub fn add_server(
143        &mut self,
144        config: ServerConfig,
145        cx: &ModelContext<Self>,
146    ) -> Task<anyhow::Result<()>> {
147        let server_id = config.id.clone();
148
149        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
150            return Task::ready(Ok(()));
151        }
152
153        let task = {
154            let server_id = server_id.clone();
155            cx.spawn(|this, mut cx| async move {
156                let server = Arc::new(ContextServer::new(config));
157                server.start(&cx).await?;
158                this.update(&mut cx, |this, cx| {
159                    this.servers.insert(server_id.clone(), server);
160                    this.pending_servers.remove(&server_id);
161                    cx.emit(Event::ServerStarted {
162                        server_id: server_id.clone(),
163                    });
164                })?;
165                Ok(())
166            })
167        };
168
169        self.pending_servers.insert(server_id);
170        task
171    }
172
173    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
174        self.servers.get(id).cloned()
175    }
176
177    pub fn remove_server(&mut self, id: &str, cx: &ModelContext<Self>) -> Task<anyhow::Result<()>> {
178        let id = id.to_string();
179        cx.spawn(|this, mut cx| async move {
180            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
181                server.stop().await?;
182            }
183            this.update(&mut cx, |this, cx| {
184                this.pending_servers.remove(&id);
185                cx.emit(Event::ServerStopped {
186                    server_id: id.clone(),
187                })
188            })?;
189            Ok(())
190        })
191    }
192
193    pub fn restart_server(
194        &mut self,
195        id: &str,
196        cx: &mut ModelContext<Self>,
197    ) -> Task<anyhow::Result<()>> {
198        let id = id.to_string();
199        cx.spawn(|this, mut cx| async move {
200            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
201                server.stop().await?;
202                let config = server.config.clone();
203                let new_server = Arc::new(ContextServer::new(config));
204                new_server.start(&cx).await?;
205                this.update(&mut cx, |this, cx| {
206                    this.servers.insert(id.clone(), new_server);
207                    cx.emit(Event::ServerStopped {
208                        server_id: id.clone(),
209                    });
210                    cx.emit(Event::ServerStarted {
211                        server_id: id.clone(),
212                    });
213                })?;
214            }
215            Ok(())
216        })
217    }
218
219    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
220        self.servers.values().cloned().collect()
221    }
222
223    pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
224        let current_servers = self
225            .servers()
226            .into_iter()
227            .map(|server| (server.id.clone(), server.config.clone()))
228            .collect::<HashMap<_, _>>();
229
230        let new_servers = settings
231            .servers
232            .iter()
233            .map(|config| (config.id.clone(), config.clone()))
234            .collect::<HashMap<_, _>>();
235
236        let servers_to_add = new_servers
237            .values()
238            .filter(|config| !current_servers.contains_key(&config.id))
239            .cloned()
240            .collect::<Vec<_>>();
241
242        let servers_to_remove = current_servers
243            .keys()
244            .filter(|id| !new_servers.contains_key(*id))
245            .cloned()
246            .collect::<Vec<_>>();
247
248        log::trace!("servers_to_add={:?}", servers_to_add);
249        for config in servers_to_add {
250            self.add_server(config, cx).detach_and_log_err(cx);
251        }
252
253        for id in servers_to_remove {
254            self.remove_server(&id, cx).detach_and_log_err(cx);
255        }
256    }
257}