manager.rs

//! This module implements a context server management system for Zed.
//!
//! It provides functionality to:
//! - Define and load context server settings
//! - Manage individual context servers (start, stop, restart)
//! - Keep the set of running context servers for a project in sync with settings
//!
//! Key components:
//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages the context servers of a single project and emits
//!   `Event`s when servers start or stop
//!
//! The manager observes both the context server factory registry and the settings store,
//! and re-synchronizes its servers whenever either of them changes.
 16
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20use anyhow::{bail, Result};
 21use collections::HashMap;
 22use command_palette_hooks::CommandPaletteFilter;
 23use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
 24use log;
 25use parking_lot::RwLock;
 26use project::Project;
 27use schemars::JsonSchema;
 28use serde::{Deserialize, Serialize};
 29use settings::{Settings, SettingsSources, SettingsStore};
 30use util::ResultExt as _;
 31
 32use crate::{
 33    client::{self, Client},
 34    types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
 35};
 36
 37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 38pub struct ContextServerSettings {
 39    #[serde(default)]
 40    pub context_servers: HashMap<Arc<str>, ServerConfig>,
 41}
 42
 43#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
 44pub struct ServerConfig {
 45    pub command: Option<ServerCommand>,
 46    pub settings: Option<serde_json::Value>,
 47}
 48
 49#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 50pub struct ServerCommand {
 51    pub path: String,
 52    pub args: Vec<String>,
 53    pub env: Option<HashMap<String, String>>,
 54}
 55
 56impl Settings for ContextServerSettings {
 57    const KEY: Option<&'static str> = None;
 58
 59    type FileContent = Self;
 60
 61    fn load(
 62        sources: SettingsSources<Self::FileContent>,
 63        _: &mut gpui::AppContext,
 64    ) -> anyhow::Result<Self> {
 65        sources.json_merge()
 66    }
 67}
 68
 69pub struct ContextServer {
 70    pub id: Arc<str>,
 71    pub config: Arc<ServerConfig>,
 72    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 73}
 74
 75impl ContextServer {
 76    pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
 77        Self {
 78            id,
 79            config,
 80            client: RwLock::new(None),
 81        }
 82    }
 83
 84    pub fn id(&self) -> Arc<str> {
 85        self.id.clone()
 86    }
 87
 88    pub fn config(&self) -> Arc<ServerConfig> {
 89        self.config.clone()
 90    }
 91
 92    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
 93        self.client.read().clone()
 94    }
 95
 96    pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
 97        log::info!("starting context server {}", self.id);
 98        let Some(command) = &self.config.command else {
 99            bail!("no command specified for server {}", self.id);
100        };
101        let client = Client::new(
102            client::ContextServerId(self.id.clone()),
103            client::ModelContextServerBinary {
104                executable: Path::new(&command.path).to_path_buf(),
105                args: command.args.clone(),
106                env: command.env.clone(),
107            },
108            cx.clone(),
109        )?;
110
111        let protocol = crate::protocol::ModelContextProtocol::new(client);
112        let client_info = types::Implementation {
113            name: "Zed".to_string(),
114            version: env!("CARGO_PKG_VERSION").to_string(),
115        };
116        let initialized_protocol = protocol.initialize(client_info).await?;
117
118        log::debug!(
119            "context server {} initialized: {:?}",
120            self.id,
121            initialized_protocol.initialize,
122        );
123
124        *self.client.write() = Some(Arc::new(initialized_protocol));
125        Ok(())
126    }
127
128    pub fn stop(&self) -> Result<()> {
129        let mut client = self.client.write();
130        if let Some(protocol) = client.take() {
131            drop(protocol);
132        }
133        Ok(())
134    }
135}
136
137pub struct ContextServerManager {
138    servers: HashMap<Arc<str>, Arc<ContextServer>>,
139    project: Model<Project>,
140    registry: Model<ContextServerFactoryRegistry>,
141    update_servers_task: Option<Task<Result<()>>>,
142    needs_server_update: bool,
143    _subscriptions: Vec<Subscription>,
144}
145
/// Notifications emitted by `ContextServerManager` when a server's lifecycle changes.
///
/// Derives the common traits so events can be logged, cloned and compared by subscribers
/// and in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// The server registered under `server_id` finished starting.
    ServerStarted { server_id: Arc<str> },
    /// The server registered under `server_id` was stopped.
    ServerStopped { server_id: Arc<str> },
}
150
151impl EventEmitter<Event> for ContextServerManager {}
152
153impl ContextServerManager {
154    pub fn new(
155        registry: Model<ContextServerFactoryRegistry>,
156        project: Model<Project>,
157        cx: &mut ModelContext<Self>,
158    ) -> Self {
159        let mut this = Self {
160            _subscriptions: vec![
161                cx.observe(&registry, |this, _registry, cx| {
162                    this.available_context_servers_changed(cx);
163                }),
164                cx.observe_global::<SettingsStore>(|this, cx| {
165                    this.available_context_servers_changed(cx);
166                }),
167            ],
168            project,
169            registry,
170            needs_server_update: false,
171            servers: HashMap::default(),
172            update_servers_task: None,
173        };
174        this.available_context_servers_changed(cx);
175        this
176    }
177
178    fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
179        if self.update_servers_task.is_some() {
180            self.needs_server_update = true;
181        } else {
182            self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
183                this.update(&mut cx, |this, _| {
184                    this.needs_server_update = false;
185                })?;
186
187                Self::maintain_servers(this.clone(), cx.clone()).await?;
188
189                this.update(&mut cx, |this, cx| {
190                    let has_any_context_servers = !this.servers().is_empty();
191                    if has_any_context_servers {
192                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
193                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
194                        });
195                    }
196
197                    this.update_servers_task.take();
198                    if this.needs_server_update {
199                        this.available_context_servers_changed(cx);
200                    }
201                })?;
202
203                Ok(())
204            }));
205        }
206    }
207
208    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
209        self.servers
210            .get(id)
211            .filter(|server| server.client().is_some())
212            .cloned()
213    }
214
215    pub fn restart_server(
216        &mut self,
217        id: &Arc<str>,
218        cx: &mut ModelContext<Self>,
219    ) -> Task<anyhow::Result<()>> {
220        let id = id.clone();
221        cx.spawn(|this, mut cx| async move {
222            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
223                server.stop()?;
224                let config = server.config();
225                let new_server = Arc::new(ContextServer::new(id.clone(), config));
226                new_server.clone().start(&cx).await?;
227                this.update(&mut cx, |this, cx| {
228                    this.servers.insert(id.clone(), new_server);
229                    cx.emit(Event::ServerStopped {
230                        server_id: id.clone(),
231                    });
232                    cx.emit(Event::ServerStarted {
233                        server_id: id.clone(),
234                    });
235                })?;
236            }
237            Ok(())
238        })
239    }
240
241    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
242        self.servers
243            .values()
244            .filter(|server| server.client().is_some())
245            .cloned()
246            .collect()
247    }
248
249    async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
250        let mut desired_servers = HashMap::default();
251
252        let (registry, project) = this.update(&mut cx, |this, cx| {
253            let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
254                settings::SettingsLocation {
255                    worktree_id: worktree.read(cx).id(),
256                    path: Path::new(""),
257                }
258            });
259            let settings = ContextServerSettings::get(location, cx);
260            desired_servers = settings.context_servers.clone();
261
262            (this.registry.clone(), this.project.clone())
263        })?;
264
265        for (id, factory) in
266            registry.read_with(&cx, |registry, _| registry.context_server_factories())?
267        {
268            let config = desired_servers.entry(id).or_default();
269            if config.command.is_none() {
270                if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
271                    config.command = Some(extension_command);
272                }
273            }
274        }
275
276        let mut servers_to_start = HashMap::default();
277        let mut servers_to_stop = HashMap::default();
278
279        this.update(&mut cx, |this, _cx| {
280            this.servers.retain(|id, server| {
281                if desired_servers.contains_key(id) {
282                    true
283                } else {
284                    servers_to_stop.insert(id.clone(), server.clone());
285                    false
286                }
287            });
288
289            for (id, config) in desired_servers {
290                let existing_config = this.servers.get(&id).map(|server| server.config());
291                if existing_config.as_deref() != Some(&config) {
292                    let config = Arc::new(config);
293                    let server = Arc::new(ContextServer::new(id.clone(), config));
294                    servers_to_start.insert(id.clone(), server.clone());
295                    let old_server = this.servers.insert(id.clone(), server);
296                    if let Some(old_server) = old_server {
297                        servers_to_stop.insert(id, old_server);
298                    }
299                }
300            }
301        })?;
302
303        for (id, server) in servers_to_stop {
304            server.stop().log_err();
305            this.update(&mut cx, |_, cx| {
306                cx.emit(Event::ServerStopped { server_id: id })
307            })?;
308        }
309
310        for (id, server) in servers_to_start {
311            if server.start(&cx).await.log_err().is_some() {
312                this.update(&mut cx, |_, cx| {
313                    cx.emit(Event::ServerStarted { server_id: id })
314                })?;
315            }
316        }
317
318        Ok(())
319    }
320}