//! This module implements a context server management system for Zed.
//!
//! It provides functionality to:
//! - Define and load context server settings
//! - Manage individual context servers (start, stop, restart)
//! - Maintain a per-project manager that keeps the set of context servers in
//!   sync with the settings and with the registered context server factories
//!
//! Key components:
//! - `ContextServerSettings`: Defines the structure for server configurations
//! - `ServerConfig` / `ServerCommand`: Describe a single server and the command used to launch it
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages multiple context servers
//! - `Event`: Emitted by the manager when a server is started or stopped
//!
//! The manager observes the settings store and the context server factory
//! registry, and starts or stops servers in reaction to changes in either.
16
17use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::{bail, Result};
21use collections::HashMap;
22use command_palette_hooks::CommandPaletteFilter;
23use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
24use log;
25use parking_lot::RwLock;
26use project::Project;
27use schemars::gen::SchemaGenerator;
28use schemars::schema::{InstanceType, Schema, SchemaObject};
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::{Settings, SettingsSources, SettingsStore};
32use util::ResultExt as _;
33
34use crate::{
35 client::{self, Client},
36 types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
37};
38
39#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
40pub struct ContextServerSettings {
41 #[serde(default)]
42 pub context_servers: HashMap<Arc<str>, ServerConfig>,
43}
44
45#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
46pub struct ServerConfig {
47 pub command: Option<ServerCommand>,
48 #[schemars(schema_with = "server_config_settings_json_schema")]
49 pub settings: Option<serde_json::Value>,
50}
51
52fn server_config_settings_json_schema(_generator: &mut SchemaGenerator) -> Schema {
53 Schema::Object(SchemaObject {
54 instance_type: Some(InstanceType::Object.into()),
55 ..Default::default()
56 })
57}
58
59#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
60pub struct ServerCommand {
61 pub path: String,
62 pub args: Vec<String>,
63 pub env: Option<HashMap<String, String>>,
64}
65
66impl Settings for ContextServerSettings {
67 const KEY: Option<&'static str> = None;
68
69 type FileContent = Self;
70
71 fn load(
72 sources: SettingsSources<Self::FileContent>,
73 _: &mut gpui::AppContext,
74 ) -> anyhow::Result<Self> {
75 sources.json_merge()
76 }
77}
78
79pub struct ContextServer {
80 pub id: Arc<str>,
81 pub config: Arc<ServerConfig>,
82 pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
83}
84
85impl ContextServer {
86 pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
87 Self {
88 id,
89 config,
90 client: RwLock::new(None),
91 }
92 }
93
94 pub fn id(&self) -> Arc<str> {
95 self.id.clone()
96 }
97
98 pub fn config(&self) -> Arc<ServerConfig> {
99 self.config.clone()
100 }
101
102 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
103 self.client.read().clone()
104 }
105
106 pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
107 log::info!("starting context server {}", self.id);
108 let Some(command) = &self.config.command else {
109 bail!("no command specified for server {}", self.id);
110 };
111 let client = Client::new(
112 client::ContextServerId(self.id.clone()),
113 client::ModelContextServerBinary {
114 executable: Path::new(&command.path).to_path_buf(),
115 args: command.args.clone(),
116 env: command.env.clone(),
117 },
118 cx.clone(),
119 )?;
120
121 let protocol = crate::protocol::ModelContextProtocol::new(client);
122 let client_info = types::Implementation {
123 name: "Zed".to_string(),
124 version: env!("CARGO_PKG_VERSION").to_string(),
125 };
126 let initialized_protocol = protocol.initialize(client_info).await?;
127
128 log::debug!(
129 "context server {} initialized: {:?}",
130 self.id,
131 initialized_protocol.initialize,
132 );
133
134 *self.client.write() = Some(Arc::new(initialized_protocol));
135 Ok(())
136 }
137
138 pub fn stop(&self) -> Result<()> {
139 let mut client = self.client.write();
140 if let Some(protocol) = client.take() {
141 drop(protocol);
142 }
143 Ok(())
144 }
145}
146
147pub struct ContextServerManager {
148 servers: HashMap<Arc<str>, Arc<ContextServer>>,
149 project: Model<Project>,
150 registry: Model<ContextServerFactoryRegistry>,
151 update_servers_task: Option<Task<Result<()>>>,
152 needs_server_update: bool,
153 _subscriptions: Vec<Subscription>,
154}
155
156pub enum Event {
157 ServerStarted { server_id: Arc<str> },
158 ServerStopped { server_id: Arc<str> },
159}
160
161impl EventEmitter<Event> for ContextServerManager {}
162
163impl ContextServerManager {
164 pub fn new(
165 registry: Model<ContextServerFactoryRegistry>,
166 project: Model<Project>,
167 cx: &mut ModelContext<Self>,
168 ) -> Self {
169 let mut this = Self {
170 _subscriptions: vec![
171 cx.observe(®istry, |this, _registry, cx| {
172 this.available_context_servers_changed(cx);
173 }),
174 cx.observe_global::<SettingsStore>(|this, cx| {
175 this.available_context_servers_changed(cx);
176 }),
177 ],
178 project,
179 registry,
180 needs_server_update: false,
181 servers: HashMap::default(),
182 update_servers_task: None,
183 };
184 this.available_context_servers_changed(cx);
185 this
186 }
187
188 fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
189 if self.update_servers_task.is_some() {
190 self.needs_server_update = true;
191 } else {
192 self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
193 this.update(&mut cx, |this, _| {
194 this.needs_server_update = false;
195 })?;
196
197 Self::maintain_servers(this.clone(), cx.clone()).await?;
198
199 this.update(&mut cx, |this, cx| {
200 let has_any_context_servers = !this.servers().is_empty();
201 if has_any_context_servers {
202 CommandPaletteFilter::update_global(cx, |filter, _cx| {
203 filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
204 });
205 }
206
207 this.update_servers_task.take();
208 if this.needs_server_update {
209 this.available_context_servers_changed(cx);
210 }
211 })?;
212
213 Ok(())
214 }));
215 }
216 }
217
218 pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
219 self.servers
220 .get(id)
221 .filter(|server| server.client().is_some())
222 .cloned()
223 }
224
225 pub fn restart_server(
226 &mut self,
227 id: &Arc<str>,
228 cx: &mut ModelContext<Self>,
229 ) -> Task<anyhow::Result<()>> {
230 let id = id.clone();
231 cx.spawn(|this, mut cx| async move {
232 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
233 server.stop()?;
234 let config = server.config();
235 let new_server = Arc::new(ContextServer::new(id.clone(), config));
236 new_server.clone().start(&cx).await?;
237 this.update(&mut cx, |this, cx| {
238 this.servers.insert(id.clone(), new_server);
239 cx.emit(Event::ServerStopped {
240 server_id: id.clone(),
241 });
242 cx.emit(Event::ServerStarted {
243 server_id: id.clone(),
244 });
245 })?;
246 }
247 Ok(())
248 })
249 }
250
251 pub fn servers(&self) -> Vec<Arc<ContextServer>> {
252 self.servers
253 .values()
254 .filter(|server| server.client().is_some())
255 .cloned()
256 .collect()
257 }
258
259 async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
260 let mut desired_servers = HashMap::default();
261
262 let (registry, project) = this.update(&mut cx, |this, cx| {
263 let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
264 settings::SettingsLocation {
265 worktree_id: worktree.read(cx).id(),
266 path: Path::new(""),
267 }
268 });
269 let settings = ContextServerSettings::get(location, cx);
270 desired_servers = settings.context_servers.clone();
271
272 (this.registry.clone(), this.project.clone())
273 })?;
274
275 for (id, factory) in
276 registry.read_with(&cx, |registry, _| registry.context_server_factories())?
277 {
278 let config = desired_servers.entry(id).or_default();
279 if config.command.is_none() {
280 if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
281 config.command = Some(extension_command);
282 }
283 }
284 }
285
286 let mut servers_to_start = HashMap::default();
287 let mut servers_to_stop = HashMap::default();
288
289 this.update(&mut cx, |this, _cx| {
290 this.servers.retain(|id, server| {
291 if desired_servers.contains_key(id) {
292 true
293 } else {
294 servers_to_stop.insert(id.clone(), server.clone());
295 false
296 }
297 });
298
299 for (id, config) in desired_servers {
300 let existing_config = this.servers.get(&id).map(|server| server.config());
301 if existing_config.as_deref() != Some(&config) {
302 let config = Arc::new(config);
303 let server = Arc::new(ContextServer::new(id.clone(), config));
304 servers_to_start.insert(id.clone(), server.clone());
305 let old_server = this.servers.insert(id.clone(), server);
306 if let Some(old_server) = old_server {
307 servers_to_stop.insert(id, old_server);
308 }
309 }
310 }
311 })?;
312
313 for (id, server) in servers_to_stop {
314 server.stop().log_err();
315 this.update(&mut cx, |_, cx| {
316 cx.emit(Event::ServerStopped { server_id: id })
317 })?;
318 }
319
320 for (id, server) in servers_to_start {
321 if server.start(&cx).await.log_err().is_some() {
322 this.update(&mut cx, |_, cx| {
323 cx.emit(Event::ServerStarted { server_id: id })
324 })?;
325 }
326 }
327
328 Ok(())
329 }
330}