1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
10//! - `ContextServer`: Represents an individual context server
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::{bail, Result};
21use collections::HashMap;
22use command_palette_hooks::CommandPaletteFilter;
23use gpui::{AsyncAppContext, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel};
24use log;
25use parking_lot::RwLock;
26use project::Project;
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use settings::{Settings, SettingsSources, SettingsStore};
30use util::ResultExt as _;
31
32use crate::{
33 client::{self, Client},
34 types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
35};
36
37#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
38pub struct ContextServerSettings {
39 #[serde(default)]
40 pub context_servers: HashMap<Arc<str>, ServerConfig>,
41}
42
43#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug, Default)]
44pub struct ServerConfig {
45 pub command: Option<ServerCommand>,
46 pub settings: Option<serde_json::Value>,
47}
48
49#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
50pub struct ServerCommand {
51 pub path: String,
52 pub args: Vec<String>,
53 pub env: Option<HashMap<String, String>>,
54}
55
56impl Settings for ContextServerSettings {
57 const KEY: Option<&'static str> = None;
58
59 type FileContent = Self;
60
61 fn load(
62 sources: SettingsSources<Self::FileContent>,
63 _: &mut gpui::AppContext,
64 ) -> anyhow::Result<Self> {
65 sources.json_merge()
66 }
67}
68
/// A single context server: its id, the configuration it was created with,
/// and the initialized protocol handle once it has been started.
pub struct ContextServer {
    /// Identifier of this server.
    pub id: Arc<str>,
    /// Configuration this server was created with.
    pub config: Arc<ServerConfig>,
    /// `None` until `start` succeeds, and `None` again after `stop`.
    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
}
74
75impl ContextServer {
76 pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
77 Self {
78 id,
79 config,
80 client: RwLock::new(None),
81 }
82 }
83
84 pub fn id(&self) -> Arc<str> {
85 self.id.clone()
86 }
87
88 pub fn config(&self) -> Arc<ServerConfig> {
89 self.config.clone()
90 }
91
92 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
93 self.client.read().clone()
94 }
95
96 pub async fn start(self: Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
97 log::info!("starting context server {}", self.id);
98 let Some(command) = &self.config.command else {
99 bail!("no command specified for server {}", self.id);
100 };
101 let client = Client::new(
102 client::ContextServerId(self.id.clone()),
103 client::ModelContextServerBinary {
104 executable: Path::new(&command.path).to_path_buf(),
105 args: command.args.clone(),
106 env: command.env.clone(),
107 },
108 cx.clone(),
109 )?;
110
111 let protocol = crate::protocol::ModelContextProtocol::new(client);
112 let client_info = types::Implementation {
113 name: "Zed".to_string(),
114 version: env!("CARGO_PKG_VERSION").to_string(),
115 };
116 let initialized_protocol = protocol.initialize(client_info).await?;
117
118 log::debug!(
119 "context server {} initialized: {:?}",
120 self.id,
121 initialized_protocol.initialize,
122 );
123
124 *self.client.write() = Some(Arc::new(initialized_protocol));
125 Ok(())
126 }
127
128 pub fn stop(&self) -> Result<()> {
129 let mut client = self.client.write();
130 if let Some(protocol) = client.take() {
131 drop(protocol);
132 }
133 Ok(())
134 }
135}
136
137pub struct ContextServerManager {
138 servers: HashMap<Arc<str>, Arc<ContextServer>>,
139 project: Model<Project>,
140 registry: Model<ContextServerFactoryRegistry>,
141 update_servers_task: Option<Task<Result<()>>>,
142 needs_server_update: bool,
143 _subscriptions: Vec<Subscription>,
144}
145
/// Notifications emitted by `ContextServerManager` as servers come and go.
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so that events can be
/// logged, forwarded and compared (for example in tests).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Event {
    /// A server with the given id was started successfully.
    ServerStarted { server_id: Arc<str> },
    /// A server with the given id was stopped.
    ServerStopped { server_id: Arc<str> },
}
150
// Allows `ContextServerManager` to emit `Event`s (see the `cx.emit` calls in
// its `restart_server` and `maintain_servers` methods).
impl EventEmitter<Event> for ContextServerManager {}
152
153impl ContextServerManager {
154 pub fn new(
155 registry: Model<ContextServerFactoryRegistry>,
156 project: Model<Project>,
157 cx: &mut ModelContext<Self>,
158 ) -> Self {
159 let mut this = Self {
160 _subscriptions: vec![
161 cx.observe(®istry, |this, _registry, cx| {
162 this.available_context_servers_changed(cx);
163 }),
164 cx.observe_global::<SettingsStore>(|this, cx| {
165 this.available_context_servers_changed(cx);
166 }),
167 ],
168 project,
169 registry,
170 needs_server_update: false,
171 servers: HashMap::default(),
172 update_servers_task: None,
173 };
174 this.available_context_servers_changed(cx);
175 this
176 }
177
178 fn available_context_servers_changed(&mut self, cx: &mut ModelContext<Self>) {
179 if self.update_servers_task.is_some() {
180 self.needs_server_update = true;
181 } else {
182 self.update_servers_task = Some(cx.spawn(|this, mut cx| async move {
183 this.update(&mut cx, |this, _| {
184 this.needs_server_update = false;
185 })?;
186
187 Self::maintain_servers(this.clone(), cx.clone()).await?;
188
189 this.update(&mut cx, |this, cx| {
190 let has_any_context_servers = !this.servers().is_empty();
191 if has_any_context_servers {
192 CommandPaletteFilter::update_global(cx, |filter, _cx| {
193 filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
194 });
195 }
196
197 this.update_servers_task.take();
198 if this.needs_server_update {
199 this.available_context_servers_changed(cx);
200 }
201 })?;
202
203 Ok(())
204 }));
205 }
206 }
207
208 pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
209 self.servers
210 .get(id)
211 .filter(|server| server.client().is_some())
212 .cloned()
213 }
214
215 pub fn restart_server(
216 &mut self,
217 id: &Arc<str>,
218 cx: &mut ModelContext<Self>,
219 ) -> Task<anyhow::Result<()>> {
220 let id = id.clone();
221 cx.spawn(|this, mut cx| async move {
222 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
223 server.stop()?;
224 let config = server.config();
225 let new_server = Arc::new(ContextServer::new(id.clone(), config));
226 new_server.clone().start(&cx).await?;
227 this.update(&mut cx, |this, cx| {
228 this.servers.insert(id.clone(), new_server);
229 cx.emit(Event::ServerStopped {
230 server_id: id.clone(),
231 });
232 cx.emit(Event::ServerStarted {
233 server_id: id.clone(),
234 });
235 })?;
236 }
237 Ok(())
238 })
239 }
240
241 pub fn servers(&self) -> Vec<Arc<ContextServer>> {
242 self.servers
243 .values()
244 .filter(|server| server.client().is_some())
245 .cloned()
246 .collect()
247 }
248
249 async fn maintain_servers(this: WeakModel<Self>, mut cx: AsyncAppContext) -> Result<()> {
250 let mut desired_servers = HashMap::default();
251
252 let (registry, project) = this.update(&mut cx, |this, cx| {
253 let location = this.project.read(cx).worktrees(cx).next().map(|worktree| {
254 settings::SettingsLocation {
255 worktree_id: worktree.read(cx).id(),
256 path: Path::new(""),
257 }
258 });
259 let settings = ContextServerSettings::get(location, cx);
260 desired_servers = settings.context_servers.clone();
261
262 (this.registry.clone(), this.project.clone())
263 })?;
264
265 for (id, factory) in
266 registry.read_with(&cx, |registry, _| registry.context_server_factories())?
267 {
268 let config = desired_servers.entry(id).or_default();
269 if config.command.is_none() {
270 if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
271 config.command = Some(extension_command);
272 }
273 }
274 }
275
276 let mut servers_to_start = HashMap::default();
277 let mut servers_to_stop = HashMap::default();
278
279 this.update(&mut cx, |this, _cx| {
280 this.servers.retain(|id, server| {
281 if desired_servers.contains_key(id) {
282 true
283 } else {
284 servers_to_stop.insert(id.clone(), server.clone());
285 false
286 }
287 });
288
289 for (id, config) in desired_servers {
290 let existing_config = this.servers.get(&id).map(|server| server.config());
291 if existing_config.as_deref() != Some(&config) {
292 let config = Arc::new(config);
293 let server = Arc::new(ContextServer::new(id.clone(), config));
294 servers_to_start.insert(id.clone(), server.clone());
295 let old_server = this.servers.insert(id.clone(), server);
296 if let Some(old_server) = old_server {
297 servers_to_stop.insert(id, old_server);
298 }
299 }
300 }
301 })?;
302
303 for (id, server) in servers_to_stop {
304 server.stop().log_err();
305 this.update(&mut cx, |_, cx| {
306 cx.emit(Event::ServerStopped { server_id: id })
307 })?;
308 }
309
310 for (id, server) in servers_to_start {
311 if server.start(&cx).await.log_err().is_some() {
312 this.update(&mut cx, |_, cx| {
313 cx.emit(Event::ServerStarted { server_id: id })
314 })?;
315 }
316 }
317
318 Ok(())
319 }
320}