1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
10//! - `ContextServer`: Represents an individual context server
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::{bail, Result};
21use collections::HashMap;
22use command_palette_hooks::CommandPaletteFilter;
23use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
24use log;
25use parking_lot::RwLock;
26use project::Project;
27use schemars::gen::SchemaGenerator;
28use schemars::schema::{InstanceType, Schema, SchemaObject};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::{Settings, SettingsSources, SettingsStore};
32use util::ResultExt as _;
33
34use crate::{
35 client::{self, Client},
36 types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
37};
38
39#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
40pub struct ContextServerSettings {
41 /// Settings for context servers used in the Assistant.
42 #[serde(default)]
43 pub context_servers: HashMap<Arc<str>, ServerConfig>,
44}
45
46#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
47pub struct ServerConfig {
48 /// The command to run this context server.
49 ///
50 /// This will override the command set by an extension.
51 pub command: Option<ServerCommand>,
52 /// The settings for this context server.
53 ///
54 /// Consult the documentation for the context server to see what settings
55 /// are supported.
56 #[schemars(schema_with = "server_config_settings_json_schema")]
57 pub settings: Option<serde_json::Value>,
58}
59
60fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
61 Schema::Object(SchemaObject {
62 instance_type: Some(InstanceType::Object.into()),
63 ..Default::default()
64 })
65}
66
67#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
68pub struct ServerCommand {
69 pub path: String,
70 pub args: Vec<String>,
71 pub env: Option<HashMap<String, String>>,
72}
73
74impl Settings for ContextServerSettings {
75 const KEY: Option<&'static str> = None;
76
77 type FileContent = Self;
78
79 fn load(
80 sources: SettingsSources<Self::FileContent>,
81 _: &mut gpui::AppContext,
82 ) -> anyhow::Result<Self> {
83 sources.json_merge()
84 }
85}
86
87pub struct ContextServer {
88 pub id: Arc<str>,
89 pub config: Arc<ServerConfig>,
90 pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
91}
92
93impl ContextServer {
94 pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
95 Self {
96 id,
97 config,
98 client: RwLock::new(None),
99 }
100 }
101
102 pub fn id(&self) -> Arc<str> {
103 self.id.clone()
104 }
105
106 pub fn config(&self) -> Arc<ServerConfig> {
107 self.config.clone()
108 }
109
110 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
111 self.client.read().clone()
112 }
113
114 pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
115 log::info!("starting context server {}", self.id);
116 let Some(command) = &self.config.command else {
117 bail!("no command specified for server {}", self.id);
118 };
119 let client = Client::new(
120 client::ContextServerId(self.id.clone()),
121 client::ModelContextServerBinary {
122 executable: Path::new(&command.path).to_path_buf(),
123 args: command.args.clone(),
124 env: command.env.clone(),
125 },
126 cx.clone(),
127 )?;
128
129 let protocol = crate::protocol::ModelContextProtocol::new(client);
130 let client_info = types::Implementation {
131 name: "Zed".to_string(),
132 version: env!("CARGO_PKG_VERSION").to_string(),
133 };
134 let initialized_protocol = protocol.initialize(client_info).await?;
135
136 log::debug!(
137 "context server {} initialized: {:?}",
138 self.id,
139 initialized_protocol.initialize,
140 );
141
142 *self.client.write() = Some(Arc::new(initialized_protocol));
143 Ok(())
144 }
145
146 pub fn stop(&self) -> Result<()> {
147 let mut client = self.client.write();
148 if let Some(protocol) = client.take() {
149 drop(protocol);
150 }
151 Ok(())
152 }
153}
154
155pub struct ContextServerManager {
156 servers: HashMap<Arc<str>, Arc<ContextServer>>,
157 project: Model<Project>,
158 registry: Model<ContextServerFactoryRegistry>,
159 update_servers_task: Option<Task<Result<()>>>,
160 needs_server_update: bool,
161 _subscriptions: Vec<Subscription>,
162}
163
164pub enum Event {
165 ServerStarted { server_id: Arc<str> },
166 ServerStopped { server_id: Arc<str> },
167}
168
169impl EventEmitter<Event> for ContextServerManager {}
170
171impl ContextServerManager {
172 pub fn new(
173 registry: Model<ContextServerFactoryRegistry>,
174 project: Model<Project>,
175 cx: &mut ModelContext<Self>,
176 ) -> Self {
177 let mut this = Self {
178 _subscriptions: vec![
179 cx.observe(®istry, |this, _registry, cx| {
180 this.available_context_servers_changed(cx);
181 }),
182 cx.observe_global::<SettingsStore>(|this, cx| {
183 this.available_context_servers_changed(cx);
184 }),
185 ],
186 project,
187 registry,
188 needs_server_update: false,
189 servers: HashMap::default(),
190 update_servers_task: None,
191 };
192 this.available_context_servers_changed(cx);
193 this
194 }
195
196 fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
197 if self.update_servers_task.is_some() {
198 self.needs_server_update = true;
199 } else {
200 self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
201 this.update(&mut cx, |this, _| {
202 this.needs_server_update = false;
203 })?;
204
205 Self::maintain_servers(this.clone(), cx.clone()).await?;
206
207 this.update(&mut cx, |this, cx| {
208 let has_any_context_servers = !this.servers().is_empty();
209 if has_any_context_servers {
210 CommandPaletteFilter::update_global(cx, |filter, _cx| {
211 filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
212 });
213 }
214
215 this.update_servers_task.take();
216 if this.needs_server_update {
217 this.available_context_servers_changed(cx);
218 }
219 })?;
220
221 Ok(())
222 }));
223 }
224 }
225
226 pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
227 self.servers
228 .get(id)
229 .filter(|server| server.client().is_some())
230 .cloned()
231 }
232
233 pub fn restart_server(
234 &mut self,
235 id: &Arc<str>,
236 cx: &mut ModelContext<Self>,
237 ) -> Task<anyhow::Result<()>> {
238 let id = id.clone();
239 cx.spawn(|this, mut cx| async move {
240 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
241 server.stop()?;
242 let config = server.config();
243 let new_server = Arc::new(ContextServer::new(id.clone(), config));
244 new_server.clone().start(&cx).await?;
245 this.update(&mut cx, |this, cx| {
246 this.servers.insert(id.clone(), new_server);
247 cx.emit(Event::ServerStopped {
248 server_id: id.clone(),
249 });
250 cx.emit(Event::ServerStarted {
251 server_id: id.clone(),
252 });
253 })?;
254 }
255 Ok(())
256 })
257 }
258
259 pub fn servers(&self) -> Vec<Arc<ContextServer>> {
260 self.servers
261 .values()
262 .filter(|server| server.client().is_some())
263 .cloned()
264 .collect()
265 }
266
267 async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
268 let mut desired_servers = HashMap::default();
269
270 let (registry, project) = this.update(&mut cx, |this, cx| {
271 let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
272 settings::SettingsLocation {
273 worktree_id: worktree.read(cx).id(),
274 path: Path::new(""),
275 }
276 });
277 let settings = ContextServerSettings::get(location, cx);
278 desired_servers = settings.context_servers.clone();
279
280 (this.registry.clone(), this.project.clone())
281 })?;
282
283 for (id, factory) in
284 registry.read_with(&cx, |registry, _| registry.context_server_factories())?
285 {
286 let config = desired_servers.entry(id).or_default();
287 if config.command.is_none() {
288 if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
289 config.command = Some(extension_command);
290 }
291 }
292 }
293
294 let mut servers_to_start = HashMap::default();
295 let mut servers_to_stop = HashMap::default();
296
297 this.update(&mut cx, |this, _cx| {
298 this.servers.retain(|id, server| {
299 if desired_servers.contains_key(id) {
300 true
301 } else {
302 servers_to_stop.insert(id.clone(), server.clone());
303 false
304 }
305 });
306
307 for (id, config) in desired_servers {
308 let existing_config = this.servers.get(&id).map(|server| server.config());
309 if existing_config.as_deref() != Some(&config) {
310 let config = Arc::new(config);
311 let server = Arc::new(ContextServer::new(id.clone(), config));
312 servers_to_start.insert(id.clone(), server.clone());
313 let old_server = this.servers.insert(id.clone(), server);
314 if let Some(old_server) = old_server {
315 servers_to_stop.insert(id, old_server);
316 }
317 }
318 }
319 })?;
320
321 for (id, server) in servers_to_stop {
322 server.stop().log_err();
323 this.update(&mut cx, |_, cx| {
324 cx.emit(Event::ServerStopped { server_id: id })
325 })?;
326 }
327
328 for (id, server) in servers_to_start {
329 if server.start(&cx).await.log_err().is_some() {
330 this.update(&mut cx, |_, cx| {
331 cx.emit(Event::ServerStarted { server_id: id })
332 })?;
333 }
334 }
335
336 Ok(())
337 }
338}