manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use collections::{HashMap, HashSet};
 18use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task};
 19use log;
 20use parking_lot::RwLock;
 21use schemars::JsonSchema;
 22use serde::{Deserialize, Serialize};
 23use settings::{Settings, SettingsSources, SettingsStore};
 24use std::path::Path;
 25use std::sync::Arc;
 26
 27use crate::{
 28    client::{self, Client},
 29    types,
 30};
 31
 32#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 33pub struct ContextServerSettings {
 34    pub servers: Vec<ServerConfig>,
 35}
 36
 37#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 38pub struct ServerConfig {
 39    pub id: String,
 40    pub executable: String,
 41    pub args: Vec<String>,
 42}
 43
 44impl Settings for ContextServerSettings {
 45    const KEY: Option<&'static str> = Some("experimental.context_servers");
 46
 47    type FileContent = Self;
 48
 49    fn load(
 50        sources: SettingsSources<Self::FileContent>,
 51        _: &mut gpui::AppContext,
 52    ) -> anyhow::Result<Self> {
 53        sources.json_merge()
 54    }
 55}
 56
 57pub struct ContextServer {
 58    pub id: String,
 59    pub config: ServerConfig,
 60    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 61}
 62
 63impl ContextServer {
 64    fn new(config: ServerConfig) -> Self {
 65        Self {
 66            id: config.id.clone(),
 67            config,
 68            client: RwLock::new(None),
 69        }
 70    }
 71
 72    async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
 73        log::info!("starting context server {}", self.config.id);
 74        let client = Client::new(
 75            client::ContextServerId(self.config.id.clone()),
 76            client::ModelContextServerBinary {
 77                executable: Path::new(&self.config.executable).to_path_buf(),
 78                args: self.config.args.clone(),
 79                env: None,
 80            },
 81            cx.clone(),
 82        )?;
 83
 84        let protocol = crate::protocol::ModelContextProtocol::new(client);
 85        let client_info = types::EntityInfo {
 86            name: "Zed".to_string(),
 87            version: env!("CARGO_PKG_VERSION").to_string(),
 88        };
 89        let initialized_protocol = protocol.initialize(client_info).await?;
 90
 91        log::debug!(
 92            "context server {} initialized: {:?}",
 93            self.config.id,
 94            initialized_protocol.initialize,
 95        );
 96
 97        *self.client.write() = Some(Arc::new(initialized_protocol));
 98        Ok(())
 99    }
100
101    async fn stop(&self) -> anyhow::Result<()> {
102        let mut client = self.client.write();
103        if let Some(protocol) = client.take() {
104            drop(protocol);
105        }
106        Ok(())
107    }
108}
109
110/// A Context server manager manages the starting and stopping
111/// of all servers. To obtain a server to interact with, a crate
112/// must go through the `GlobalContextServerManager` which holds
113/// a model to the ContextServerManager.
114pub struct ContextServerManager {
115    servers: HashMap<String, Arc<ContextServer>>,
116    pending_servers: HashSet<String>,
117}
118
/// Notifications emitted by the [`ContextServerManager`] model as servers
/// change state.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Event {
    /// The server with `server_id` finished starting and is now registered.
    ServerStarted { server_id: String },
    /// The server with `server_id` was stopped or removed.
    ServerStopped { server_id: String },
}
123
124impl Global for ContextServerManager {}
125impl EventEmitter<Event> for ContextServerManager {}
126
127impl ContextServerManager {
128    pub fn new() -> Self {
129        Self {
130            servers: HashMap::default(),
131            pending_servers: HashSet::default(),
132        }
133    }
134    pub fn global(cx: &AppContext) -> Model<Self> {
135        cx.global::<GlobalContextServerManager>().0.clone()
136    }
137
138    pub fn add_server(
139        &mut self,
140        config: ServerConfig,
141        cx: &mut ModelContext<Self>,
142    ) -> Task<anyhow::Result<()>> {
143        let server_id = config.id.clone();
144        let server_id2 = config.id.clone();
145
146        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
147            return Task::ready(Ok(()));
148        }
149
150        let task = cx.spawn(|this, mut cx| async move {
151            let server = Arc::new(ContextServer::new(config));
152            server.start(&cx).await?;
153            this.update(&mut cx, |this, cx| {
154                this.servers.insert(server_id.clone(), server);
155                this.pending_servers.remove(&server_id);
156                cx.emit(Event::ServerStarted {
157                    server_id: server_id.clone(),
158                });
159            })?;
160            Ok(())
161        });
162
163        self.pending_servers.insert(server_id2);
164        task
165    }
166
167    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
168        self.servers.get(id).cloned()
169    }
170
171    pub fn remove_server(
172        &mut self,
173        id: &str,
174        cx: &mut ModelContext<Self>,
175    ) -> Task<anyhow::Result<()>> {
176        let id = id.to_string();
177        cx.spawn(|this, mut cx| async move {
178            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
179                server.stop().await?;
180            }
181            this.update(&mut cx, |this, cx| {
182                this.pending_servers.remove(&id);
183                cx.emit(Event::ServerStopped {
184                    server_id: id.clone(),
185                })
186            })?;
187            Ok(())
188        })
189    }
190
191    pub fn restart_server(
192        &mut self,
193        id: &str,
194        cx: &mut ModelContext<Self>,
195    ) -> Task<anyhow::Result<()>> {
196        let id = id.to_string();
197        cx.spawn(|this, mut cx| async move {
198            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
199                server.stop().await?;
200                let config = server.config.clone();
201                let new_server = Arc::new(ContextServer::new(config));
202                new_server.start(&cx).await?;
203                this.update(&mut cx, |this, cx| {
204                    this.servers.insert(id.clone(), new_server);
205                    cx.emit(Event::ServerStopped {
206                        server_id: id.clone(),
207                    });
208                    cx.emit(Event::ServerStarted {
209                        server_id: id.clone(),
210                    });
211                })?;
212            }
213            Ok(())
214        })
215    }
216
217    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
218        self.servers.values().cloned().collect()
219    }
220
221    pub fn model(cx: &mut AppContext) -> Model<Self> {
222        cx.new_model(|_cx| ContextServerManager::new())
223    }
224}
225
226pub struct GlobalContextServerManager(Model<ContextServerManager>);
227impl Global for GlobalContextServerManager {}
228
229impl GlobalContextServerManager {
230    fn register(cx: &mut AppContext) {
231        let model = ContextServerManager::model(cx);
232        cx.set_global(Self(model));
233    }
234}
235
236pub fn init(cx: &mut AppContext) {
237    ContextServerSettings::register(cx);
238    GlobalContextServerManager::register(cx);
239    cx.observe_global::<SettingsStore>(|cx| {
240        let manager = ContextServerManager::global(cx);
241        cx.update_model(&manager, |manager, cx| {
242            let settings = ContextServerSettings::get_global(cx);
243            let current_servers: HashMap<String, ServerConfig> = manager
244                .servers()
245                .into_iter()
246                .map(|server| (server.id.clone(), server.config.clone()))
247                .collect();
248
249            let new_servers = settings
250                .servers
251                .iter()
252                .map(|config| (config.id.clone(), config.clone()))
253                .collect::<HashMap<_, _>>();
254
255            let servers_to_add = new_servers
256                .values()
257                .filter(|config| !current_servers.contains_key(&config.id))
258                .cloned()
259                .collect::<Vec<_>>();
260
261            let servers_to_remove = current_servers
262                .keys()
263                .filter(|id| !new_servers.contains_key(*id))
264                .cloned()
265                .collect::<Vec<_>>();
266
267            log::trace!("servers_to_add={:?}", servers_to_add);
268            for config in servers_to_add {
269                manager.add_server(config, cx).detach_and_log_err(cx);
270            }
271
272            for id in servers_to_remove {
273                manager.remove_server(&id, cx).detach_and_log_err(cx);
274            }
275        })
276    })
277    .detach();
278}