manager.rs

  1//! This module implements a context server management system for Zed.
  2//!
  3//! It provides functionality to:
  4//! - Define and load context server settings
  5//! - Manage individual context servers (start, stop, restart)
  6//! - Maintain a global manager for all context servers
  7//!
  8//! Key components:
  9//! - `ContextServerSettings`: Defines the structure for server configurations
 10//! - `ContextServer`: Represents an individual context server
 11//! - `ContextServerManager`: Manages multiple context servers
 12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
 13//!
 14//! The module also includes initialization logic to set up the context server system
 15//! and react to changes in settings.
 16
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20use anyhow::{bail, Result};
 21use collections::HashMap;
 22use command_palette_hooks::CommandPaletteFilter;
 23use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
 24use log;
 25use parking_lot::RwLock;
 26use project::Project;
 27use schemars::gen::SchemaGenerator;
 28use schemars::schema::{InstanceType, Schema, SchemaObject};
 29use schemars::JsonSchema;
 30use serde::{Deserialize, Serialize};
 31use settings::{Settings, SettingsSources, SettingsStore};
 32use util::ResultExt as _;
 33
 34use crate::{
 35    client::{self, Client},
 36    types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
 37};
 38
 39#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 40pub struct ContextServerSettings {
 41    /// Settings for context servers used in the Assistant.
 42    #[serde(default)]
 43    pub context_servers: HashMap<Arc<str>, ServerConfig>,
 44}
 45
 46#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
 47pub struct ServerConfig {
 48    /// The command to run this context server.
 49    ///
 50    /// This will override the command set by an extension.
 51    pub command: Option<ServerCommand>,
 52    /// The settings for this context server.
 53    ///
 54    /// Consult the documentation for the context server to see what settings
 55    /// are supported.
 56    #[schemars(schema_with = "server_config_settings_json_schema")]
 57    pub settings: Option<serde_json::Value>,
 58}
 59
 60fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
 61    Schema::Object(SchemaObject {
 62        instance_type: Some(InstanceType::Object.into()),
 63        ..Default::default()
 64    })
 65}
 66
 67#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 68pub struct ServerCommand {
 69    pub path: String,
 70    pub args: Vec<String>,
 71    pub env: Option<HashMap<String, String>>,
 72}
 73
 74impl Settings for ContextServerSettings {
 75    const KEY: Option<&'static str> = None;
 76
 77    type FileContent = Self;
 78
 79    fn load(
 80        sources: SettingsSources<Self::FileContent>,
 81        _: &mut gpui::AppContext,
 82    ) -> anyhow::Result<Self> {
 83        sources.json_merge()
 84    }
 85}
 86
 87pub struct ContextServer {
 88    pub id: Arc<str>,
 89    pub config: Arc<ServerConfig>,
 90    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 91}
 92
 93impl ContextServer {
 94    pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
 95        Self {
 96            id,
 97            config,
 98            client: RwLock::new(None),
 99        }
100    }
101
102    pub fn id(&self) -> Arc<str> {
103        self.id.clone()
104    }
105
106    pub fn config(&self) -> Arc<ServerConfig> {
107        self.config.clone()
108    }
109
110    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
111        self.client.read().clone()
112    }
113
114    pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
115        log::info!("starting context server {}", self.id);
116        let Some(command) = &self.config.command else {
117            bail!("no command specified for server {}", self.id);
118        };
119        let client = Client::new(
120            client::ContextServerId(self.id.clone()),
121            client::ModelContextServerBinary {
122                executable: Path::new(&command.path).to_path_buf(),
123                args: command.args.clone(),
124                env: command.env.clone(),
125            },
126            cx.clone(),
127        )?;
128
129        let protocol = crate::protocol::ModelContextProtocol::new(client);
130        let client_info = types::Implementation {
131            name: "Zed".to_string(),
132            version: env!("CARGO_PKG_VERSION").to_string(),
133        };
134        let initialized_protocol = protocol.initialize(client_info).await?;
135
136        log::debug!(
137            "context server {} initialized: {:?}",
138            self.id,
139            initialized_protocol.initialize,
140        );
141
142        *self.client.write() = Some(Arc::new(initialized_protocol));
143        Ok(())
144    }
145
146    pub fn stop(&self) -> Result<()> {
147        let mut client = self.client.write();
148        if let Some(protocol) = client.take() {
149            drop(protocol);
150        }
151        Ok(())
152    }
153}
154
155pub struct ContextServerManager {
156    servers: HashMap<Arc<str>, Arc<ContextServer>>,
157    project: Model<Project>,
158    registry: Model<ContextServerFactoryRegistry>,
159    update_servers_task: Option<Task<Result<()>>>,
160    needs_server_update: bool,
161    _subscriptions: Vec<Subscription>,
162}
163
164pub enum Event {
165    ServerStarted { server_id: Arc<str> },
166    ServerStopped { server_id: Arc<str> },
167}
168
169impl EventEmitter<Event> for ContextServerManager {}
170
171impl ContextServerManager {
172    pub fn new(
173        registry: Model<ContextServerFactoryRegistry>,
174        project: Model<Project>,
175        cx: &mut ModelContext<Self>,
176    ) -> Self {
177        let mut this = Self {
178            _subscriptions: vec![
179                cx.observe(&registry, |this, _registry, cx| {
180                    this.available_context_servers_changed(cx);
181                }),
182                cx.observe_global::<SettingsStore>(|this, cx| {
183                    this.available_context_servers_changed(cx);
184                }),
185            ],
186            project,
187            registry,
188            needs_server_update: false,
189            servers: HashMap::default(),
190            update_servers_task: None,
191        };
192        this.available_context_servers_changed(cx);
193        this
194    }
195
196    fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
197        if self.update_servers_task.is_some() {
198            self.needs_server_update = true;
199        } else {
200            self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
201                this.update(&mut cx, |this, _| {
202                    this.needs_server_update = false;
203                })?;
204
205                Self::maintain_servers(this.clone(), cx.clone()).await?;
206
207                this.update(&mut cx, |this, cx| {
208                    let has_any_context_servers = !this.servers().is_empty();
209                    if has_any_context_servers {
210                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
211                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
212                        });
213                    }
214
215                    this.update_servers_task.take();
216                    if this.needs_server_update {
217                        this.available_context_servers_changed(cx);
218                    }
219                })?;
220
221                Ok(())
222            }));
223        }
224    }
225
226    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
227        self.servers
228            .get(id)
229            .filter(|server| server.client().is_some())
230            .cloned()
231    }
232
233    pub fn restart_server(
234        &mut self,
235        id: &Arc<str>,
236        cx: &mut ModelContext<Self>,
237    ) -> Task<anyhow::Result<()>> {
238        let id = id.clone();
239        cx.spawn(|this, mut cx| async move {
240            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
241                server.stop()?;
242                let config = server.config();
243                let new_server = Arc::new(ContextServer::new(id.clone(), config));
244                new_server.clone().start(&cx).await?;
245                this.update(&mut cx, |this, cx| {
246                    this.servers.insert(id.clone(), new_server);
247                    cx.emit(Event::ServerStopped {
248                        server_id: id.clone(),
249                    });
250                    cx.emit(Event::ServerStarted {
251                        server_id: id.clone(),
252                    });
253                })?;
254            }
255            Ok(())
256        })
257    }
258
259    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
260        self.servers
261            .values()
262            .filter(|server| server.client().is_some())
263            .cloned()
264            .collect()
265    }
266
267    async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
268        let mut desired_servers = HashMap::default();
269
270        let (registry, project) = this.update(&mut cx, |this, cx| {
271            let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
272                settings::SettingsLocation {
273                    worktree_id: worktree.read(cx).id(),
274                    path: Path::new(""),
275                }
276            });
277            let settings = ContextServerSettings::get(location, cx);
278            desired_servers = settings.context_servers.clone();
279
280            (this.registry.clone(), this.project.clone())
281        })?;
282
283        for (id, factory) in
284            registry.read_with(&cx, |registry, _| registry.context_server_factories())?
285        {
286            let config = desired_servers.entry(id).or_default();
287            if config.command.is_none() {
288                if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
289                    config.command = Some(extension_command);
290                }
291            }
292        }
293
294        let mut servers_to_start = HashMap::default();
295        let mut servers_to_stop = HashMap::default();
296
297        this.update(&mut cx, |this, _cx| {
298            this.servers.retain(|id, server| {
299                if desired_servers.contains_key(id) {
300                    true
301                } else {
302                    servers_to_stop.insert(id.clone(), server.clone());
303                    false
304                }
305            });
306
307            for (id, config) in desired_servers {
308                let existing_config = this.servers.get(&id).map(|server| server.config());
309                if existing_config.as_deref() != Some(&config) {
310                    let config = Arc::new(config);
311                    let server = Arc::new(ContextServer::new(id.clone(), config));
312                    servers_to_start.insert(id.clone(), server.clone());
313                    let old_server = this.servers.insert(id.clone(), server);
314                    if let Some(old_server) = old_server {
315                        servers_to_stop.insert(id, old_server);
316                    }
317                }
318            }
319        })?;
320
321        for (id, server) in servers_to_stop {
322            server.stop().log_err();
323            this.update(&mut cx, |_, cx| {
324                cx.emit(Event::ServerStopped { server_id: id })
325            })?;
326        }
327
328        for (id, server) in servers_to_start {
329            if server.start(&cx).await.log_err().is_some() {
330                this.update(&mut cx, |_, cx| {
331                    cx.emit(Event::ServerStarted { server_id: id })
332                })?;
333            }
334        }
335
336        Ok(())
337    }
338}