manager.rs

//! This module implements a context server management system for Zed.
//!
//! It provides functionality to:
//! - Define and load context server settings
//! - Manage individual context servers (start, stop, restart)
//! - Keep a project's context servers in sync with settings and with the
//!   server factories registered in `ContextServerFactoryRegistry`
//!
//! Key components:
//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages the context servers of a single project
//!
//! `ContextServerManager` observes both the settings store and the factory
//! registry, and reconciles its running servers whenever either one changes.
 16
 17use std::path::Path;
 18use std::sync::Arc;
 19
 20use anyhow::{bail, Result};
 21use collections::HashMap;
 22use command_palette_hooks::CommandPaletteFilter;
 23use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
 24use log;
 25use parking_lot::RwLock;
 26use project::Project;
 27use schemars::gen::SchemaGenerator;
 28use schemars::schema::{InstanceType, Schema, SchemaObject};
 29use schemars::JsonSchema;
 30use serde::{Deserialize, Serialize};
 31use settings::{Settings, SettingsSources, SettingsStore};
 32use util::ResultExt as _;
 33
 34use crate::{
 35    client::{self, Client},
 36    types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
 37};
 38
 39#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
 40pub struct ContextServerSettings {
 41    #[serde(default)]
 42    pub context_servers: HashMap<Arc<str>, ServerConfig>,
 43}
 44
 45#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
 46pub struct ServerConfig {
 47    pub command: Option<ServerCommand>,
 48    #[schemars(schema_with = "server_config_settings_json_schema")]
 49    pub settings: Option<serde_json::Value>,
 50}
 51
 52fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
 53    Schema::Object(SchemaObject {
 54        instance_type: Some(InstanceType::Object.into()),
 55        ..Default::default()
 56    })
 57}
 58
 59#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
 60pub struct ServerCommand {
 61    pub path: String,
 62    pub args: Vec<String>,
 63    pub env: Option<HashMap<String, String>>,
 64}
 65
 66impl Settings for ContextServerSettings {
 67    const KEY: Option<&'static str> = None;
 68
 69    type FileContent = Self;
 70
 71    fn load(
 72        sources: SettingsSources<Self::FileContent>,
 73        _: &mut gpui::AppContext,
 74    ) -> anyhow::Result<Self> {
 75        sources.json_merge()
 76    }
 77}
 78
 79pub struct ContextServer {
 80    pub id: Arc<str>,
 81    pub config: Arc<ServerConfig>,
 82    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
 83}
 84
 85impl ContextServer {
 86    pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
 87        Self {
 88            id,
 89            config,
 90            client: RwLock::new(None),
 91        }
 92    }
 93
 94    pub fn id(&self) -> Arc<str> {
 95        self.id.clone()
 96    }
 97
 98    pub fn config(&self) -> Arc<ServerConfig> {
 99        self.config.clone()
100    }
101
102    pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
103        self.client.read().clone()
104    }
105
106    pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
107        log::info!("starting context server {}", self.id);
108        let Some(command) = &self.config.command else {
109            bail!("no command specified for server {}", self.id);
110        };
111        let client = Client::new(
112            client::ContextServerId(self.id.clone()),
113            client::ModelContextServerBinary {
114                executable: Path::new(&command.path).to_path_buf(),
115                args: command.args.clone(),
116                env: command.env.clone(),
117            },
118            cx.clone(),
119        )?;
120
121        let protocol = crate::protocol::ModelContextProtocol::new(client);
122        let client_info = types::Implementation {
123            name: "Zed".to_string(),
124            version: env!("CARGO_PKG_VERSION").to_string(),
125        };
126        let initialized_protocol = protocol.initialize(client_info).await?;
127
128        log::debug!(
129            "context server {} initialized: {:?}",
130            self.id,
131            initialized_protocol.initialize,
132        );
133
134        *self.client.write() = Some(Arc::new(initialized_protocol));
135        Ok(())
136    }
137
138    pub fn stop(&self) -> Result<()> {
139        let mut client = self.client.write();
140        if let Some(protocol) = client.take() {
141            drop(protocol);
142        }
143        Ok(())
144    }
145}
146
147pub struct ContextServerManager {
148    servers: HashMap<Arc<str>, Arc<ContextServer>>,
149    project: Model<Project>,
150    registry: Model<ContextServerFactoryRegistry>,
151    update_servers_task: Option<Task<Result<()>>>,
152    needs_server_update: bool,
153    _subscriptions: Vec<Subscription>,
154}
155
156pub enum Event {
157    ServerStarted { server_id: Arc<str> },
158    ServerStopped { server_id: Arc<str> },
159}
160
161impl EventEmitter<Event> for ContextServerManager {}
162
163impl ContextServerManager {
164    pub fn new(
165        registry: Model<ContextServerFactoryRegistry>,
166        project: Model<Project>,
167        cx: &mut ModelContext<Self>,
168    ) -> Self {
169        let mut this = Self {
170            _subscriptions: vec![
171                cx.observe(&registry, |this, _registry, cx| {
172                    this.available_context_servers_changed(cx);
173                }),
174                cx.observe_global::<SettingsStore>(|this, cx| {
175                    this.available_context_servers_changed(cx);
176                }),
177            ],
178            project,
179            registry,
180            needs_server_update: false,
181            servers: HashMap::default(),
182            update_servers_task: None,
183        };
184        this.available_context_servers_changed(cx);
185        this
186    }
187
188    fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
189        if self.update_servers_task.is_some() {
190            self.needs_server_update = true;
191        } else {
192            self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
193                this.update(&mut cx, |this, _| {
194                    this.needs_server_update = false;
195                })?;
196
197                Self::maintain_servers(this.clone(), cx.clone()).await?;
198
199                this.update(&mut cx, |this, cx| {
200                    let has_any_context_servers = !this.servers().is_empty();
201                    if has_any_context_servers {
202                        CommandPaletteFilter::update_global(cx, |filter, _cx| {
203                            filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
204                        });
205                    }
206
207                    this.update_servers_task.take();
208                    if this.needs_server_update {
209                        this.available_context_servers_changed(cx);
210                    }
211                })?;
212
213                Ok(())
214            }));
215        }
216    }
217
218    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
219        self.servers
220            .get(id)
221            .filter(|server| server.client().is_some())
222            .cloned()
223    }
224
225    pub fn restart_server(
226        &mut self,
227        id: &Arc<str>,
228        cx: &mut ModelContext<Self>,
229    ) -> Task<anyhow::Result<()>> {
230        let id = id.clone();
231        cx.spawn(|this, mut cx| async move {
232            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
233                server.stop()?;
234                let config = server.config();
235                let new_server = Arc::new(ContextServer::new(id.clone(), config));
236                new_server.clone().start(&cx).await?;
237                this.update(&mut cx, |this, cx| {
238                    this.servers.insert(id.clone(), new_server);
239                    cx.emit(Event::ServerStopped {
240                        server_id: id.clone(),
241                    });
242                    cx.emit(Event::ServerStarted {
243                        server_id: id.clone(),
244                    });
245                })?;
246            }
247            Ok(())
248        })
249    }
250
251    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
252        self.servers
253            .values()
254            .filter(|server| server.client().is_some())
255            .cloned()
256            .collect()
257    }
258
259    async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
260        let mut desired_servers = HashMap::default();
261
262        let (registry, project) = this.update(&mut cx, |this, cx| {
263            let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
264                settings::SettingsLocation {
265                    worktree_id: worktree.read(cx).id(),
266                    path: Path::new(""),
267                }
268            });
269            let settings = ContextServerSettings::get(location, cx);
270            desired_servers = settings.context_servers.clone();
271
272            (this.registry.clone(), this.project.clone())
273        })?;
274
275        for (id, factory) in
276            registry.read_with(&cx, |registry, _| registry.context_server_factories())?
277        {
278            let config = desired_servers.entry(id).or_default();
279            if config.command.is_none() {
280                if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
281                    config.command = Some(extension_command);
282                }
283            }
284        }
285
286        let mut servers_to_start = HashMap::default();
287        let mut servers_to_stop = HashMap::default();
288
289        this.update(&mut cx, |this, _cx| {
290            this.servers.retain(|id, server| {
291                if desired_servers.contains_key(id) {
292                    true
293                } else {
294                    servers_to_stop.insert(id.clone(), server.clone());
295                    false
296                }
297            });
298
299            for (id, config) in desired_servers {
300                let existing_config = this.servers.get(&id).map(|server| server.config());
301                if existing_config.as_deref() != Some(&config) {
302                    let config = Arc::new(config);
303                    let server = Arc::new(ContextServer::new(id.clone(), config));
304                    servers_to_start.insert(id.clone(), server.clone());
305                    let old_server = this.servers.insert(id.clone(), server);
306                    if let Some(old_server) = old_server {
307                        servers_to_stop.insert(id, old_server);
308                    }
309                }
310            }
311        })?;
312
313        for (id, server) in servers_to_stop {
314            server.stop().log_err();
315            this.update(&mut cx, |_, cx| {
316                cx.emit(Event::ServerStopped { server_id: id })
317            })?;
318        }
319
320        for (id, server) in servers_to_start {
321            if server.start(&cx).await.log_err().is_some() {
322                this.update(&mut cx, |_, cx| {
323                    cx.emit(Event::ServerStarted { server_id: id })
324                })?;
325            }
326        }
327
328        Ok(())
329    }
330}