manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use collections::{HashMap, HashSet};
 18use command_palette_hooks::CommandPaletteFilter;
 19use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task};
 20use log;
 21use parking_lot::RwLock;
 22use schemars::JsonSchema;
 23use serde::{Deserialize, Serialize};
 24use settings::{Settings, SettingsSources, SettingsStore};
 25use std::path::Path;
 26use std::sync::Arc;
 27
 28use crate::CONTEXT_SERVERS_NAMESPACE;
 29use crate::{
 30    client::{self, Client},
 31    types,
 32};
 33
 34#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 35pub struct ContextServerSettings {
 36    pub servers: Vec<ServerConfig>,
 37}
 38
 39#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 40pub struct ServerConfig {
 41    pub id: String,
 42    pub executable: String,
 43    pub args: Vec<String>,
 44    pub env: Option<HashMap<String, String>>,
 45}
 46
 47impl Settings for ContextServerSettings {
 48    const KEY: Option<&'static str> = Some("experimental.context_servers");
 49
 50    type FileContent = Self;
 51
 52    fn load(
 53        sources: SettingsSources<Self::FileContent>,
 54        _: &mut gpui::AppContext,
 55    ) -> anyhow::Result<Self> {
 56        sources.json_merge()
 57    }
 58}
 59
 60pub struct ContextServer {
 61    pub id: String,
 62    pub config: ServerConfig,
 63    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 64}
 65
 66impl ContextServer {
 67    fn new(config: ServerConfig) -> Self {
 68        Self {
 69            id: config.id.clone(),
 70            config,
 71            client: RwLock::new(None),
 72        }
 73    }
 74
 75    async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
 76        log::info!("starting context server {}", self.config.id,);
 77        let client = Client::new(
 78            client::ContextServerId(self.config.id.clone()),
 79            client::ModelContextServerBinary {
 80                executable: Path::new(&self.config.executable).to_path_buf(),
 81                args: self.config.args.clone(),
 82                env: self.config.env.clone(),
 83            },
 84            cx.clone(),
 85        )?;
 86
 87        let protocol = crate::protocol::ModelContextProtocol::new(client);
 88        let client_info = types::Implementation {
 89            name: "Zed".to_string(),
 90            version: env!("CARGO_PKG_VERSION").to_string(),
 91        };
 92        let initialized_protocol = protocol.initialize(client_info).await?;
 93
 94        log::debug!(
 95            "context server {} initialized: {:?}",
 96            self.config.id,
 97            initialized_protocol.initialize,
 98        );
 99
100        *self.client.write() = Some(Arc::new(initialized_protocol));
101        Ok(())
102    }
103
104    async fn stop(&self) -> anyhow::Result<()> {
105        let mut client = self.client.write();
106        if let Some(protocol) = client.take() {
107            drop(protocol);
108        }
109        Ok(())
110    }
111}
112
113/// A Context server manager manages the starting and stopping
114/// of all servers. To obtain a server to interact with, a crate
115/// must go through the `GlobalContextServerManager` which holds
116/// a model to the ContextServerManager.
117pub struct ContextServerManager {
118    servers: HashMap<String, Arc<ContextServer>>,
119    pending_servers: HashSet<String>,
120}
121
/// Notifications emitted by [`ContextServerManager`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// Emitted by `add_server` and `restart_server` once the server with
    /// `server_id` has initialized and been registered.
    ServerStarted { server_id: String },
    /// Emitted by `remove_server` and `restart_server` for `server_id`.
    ServerStopped { server_id: String },
}
126
127impl Global for ContextServerManager {}
128impl EventEmitter<Event> for ContextServerManager {}
129
130impl Default for ContextServerManager {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl ContextServerManager {
137    pub fn new() -> Self {
138        Self {
139            servers: HashMap::default(),
140            pending_servers: HashSet::default(),
141        }
142    }
143    pub fn global(cx: &AppContext) -> Model<Self> {
144        cx.global::<GlobalContextServerManager>().0.clone()
145    }
146
147    pub fn add_server(
148        &mut self,
149        config: ServerConfig,
150        cx: &mut ModelContext<Self>,
151    ) -> Task<anyhow::Result<()>> {
152        let server_id = config.id.clone();
153
154        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
155            return Task::ready(Ok(()));
156        }
157
158        let task = {
159            let server_id = server_id.clone();
160            cx.spawn(|this, mut cx| async move {
161                let server = Arc::new(ContextServer::new(config));
162                server.start(&cx).await?;
163                this.update(&mut cx, |this, cx| {
164                    this.servers.insert(server_id.clone(), server);
165                    this.pending_servers.remove(&server_id);
166                    cx.emit(Event::ServerStarted {
167                        server_id: server_id.clone(),
168                    });
169                })?;
170                Ok(())
171            })
172        };
173
174        self.pending_servers.insert(server_id);
175        task
176    }
177
178    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
179        self.servers.get(id).cloned()
180    }
181
182    pub fn remove_server(
183        &mut self,
184        id: &str,
185        cx: &mut ModelContext<Self>,
186    ) -> Task<anyhow::Result<()>> {
187        let id = id.to_string();
188        cx.spawn(|this, mut cx| async move {
189            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
190                server.stop().await?;
191            }
192            this.update(&mut cx, |this, cx| {
193                this.pending_servers.remove(&id);
194                cx.emit(Event::ServerStopped {
195                    server_id: id.clone(),
196                })
197            })?;
198            Ok(())
199        })
200    }
201
202    pub fn restart_server(
203        &mut self,
204        id: &str,
205        cx: &mut ModelContext<Self>,
206    ) -> Task<anyhow::Result<()>> {
207        let id = id.to_string();
208        cx.spawn(|this, mut cx| async move {
209            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
210                server.stop().await?;
211                let config = server.config.clone();
212                let new_server = Arc::new(ContextServer::new(config));
213                new_server.start(&cx).await?;
214                this.update(&mut cx, |this, cx| {
215                    this.servers.insert(id.clone(), new_server);
216                    cx.emit(Event::ServerStopped {
217                        server_id: id.clone(),
218                    });
219                    cx.emit(Event::ServerStarted {
220                        server_id: id.clone(),
221                    });
222                })?;
223            }
224            Ok(())
225        })
226    }
227
228    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
229        self.servers.values().cloned().collect()
230    }
231
232    pub fn model(cx: &mut AppContext) -> Model<Self> {
233        cx.new_model(|_cx| ContextServerManager::new())
234    }
235}
236
237pub struct GlobalContextServerManager(Model<ContextServerManager>);
238impl Global for GlobalContextServerManager {}
239
240impl GlobalContextServerManager {
241    fn register(cx: &mut AppContext) {
242        let model = ContextServerManager::model(cx);
243        cx.set_global(Self(model));
244    }
245}
246
247pub fn init(cx: &mut AppContext) {
248    ContextServerSettings::register(cx);
249    GlobalContextServerManager::register(cx);
250
251    CommandPaletteFilter::update_global(cx, |filter, _cx| {
252        filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
253    });
254
255    cx.observe_global::<SettingsStore>(|cx| {
256        let manager = ContextServerManager::global(cx);
257        cx.update_model(&manager, |manager, cx| {
258            let settings = ContextServerSettings::get_global(cx);
259            let current_servers = manager
260                .servers()
261                .into_iter()
262                .map(|server| (server.id.clone(), server.config.clone()))
263                .collect::<HashMap<_, _>>();
264
265            let new_servers = settings
266                .servers
267                .iter()
268                .map(|config| (config.id.clone(), config.clone()))
269                .collect::<HashMap<_, _>>();
270
271            let servers_to_add = new_servers
272                .values()
273                .filter(|config| !current_servers.contains_key(&config.id))
274                .cloned()
275                .collect::<Vec<_>>();
276
277            let servers_to_remove = current_servers
278                .keys()
279                .filter(|id| !new_servers.contains_key(*id))
280                .cloned()
281                .collect::<Vec<_>>();
282
283            log::trace!("servers_to_add={:?}", servers_to_add);
284            for config in servers_to_add {
285                manager.add_server(config, cx).detach_and_log_err(cx);
286            }
287
288            for id in servers_to_remove {
289                manager.remove_server(&id, cx).detach_and_log_err(cx);
290            }
291
292            let has_any_context_servers = !manager.servers().is_empty();
293            CommandPaletteFilter::update_global(cx, |filter, _cx| {
294                if has_any_context_servers {
295                    filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
296                } else {
297                    filter.hide_namespace(CONTEXT_SERVERS_NAMESPACE);
298                }
299            });
300        })
301    })
302    .detach();
303}