1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
10//! - `ContextServer`: Represents an individual context server
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::{Result, bail};
21use collections::HashMap;
22use command_palette_hooks::CommandPaletteFilter;
23use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
24use log;
25use parking_lot::RwLock;
26use project::Project;
27use settings::{Settings, SettingsStore};
28use util::ResultExt as _;
29
30use crate::{ContextServerSettings, ServerConfig};
31
32use crate::{
33 CONTEXT_SERVERS_NAMESPACE, ContextServerFactoryRegistry,
34 client::{self, Client},
35 types,
36};
37
38pub struct ContextServer {
39 pub id: Arc<str>,
40 pub config: Arc<ServerConfig>,
41 pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
42}
43
44impl ContextServer {
45 pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
46 Self {
47 id,
48 config,
49 client: RwLock::new(None),
50 }
51 }
52
53 pub fn id(&self) -> Arc<str> {
54 self.id.clone()
55 }
56
57 pub fn config(&self) -> Arc<ServerConfig> {
58 self.config.clone()
59 }
60
61 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
62 self.client.read().clone()
63 }
64
65 pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
66 log::info!("starting context server {}", self.id);
67 let Some(command) = &self.config.command else {
68 bail!("no command specified for server {}", self.id);
69 };
70 let client = Client::new(
71 client::ContextServerId(self.id.clone()),
72 client::ModelContextServerBinary {
73 executable: Path::new(&command.path).to_path_buf(),
74 args: command.args.clone(),
75 env: command.env.clone(),
76 },
77 cx.clone(),
78 )?;
79
80 let protocol = crate::protocol::ModelContextProtocol::new(client);
81 let client_info = types::Implementation {
82 name: "Zed".to_string(),
83 version: env!("CARGO_PKG_VERSION").to_string(),
84 };
85 let initialized_protocol = protocol.initialize(client_info).await?;
86
87 log::debug!(
88 "context server {} initialized: {:?}",
89 self.id,
90 initialized_protocol.initialize,
91 );
92
93 *self.client.write() = Some(Arc::new(initialized_protocol));
94 Ok(())
95 }
96
97 pub fn stop(&self) -> Result<()> {
98 let mut client = self.client.write();
99 if let Some(protocol) = client.take() {
100 drop(protocol);
101 }
102 Ok(())
103 }
104}
105
106pub struct ContextServerManager {
107 servers: HashMap<Arc<str>, Arc<ContextServer>>,
108 project: Entity<Project>,
109 registry: Entity<ContextServerFactoryRegistry>,
110 update_servers_task: Option<Task<Result<()>>>,
111 needs_server_update: bool,
112 _subscriptions: Vec<Subscription>,
113}
114
/// Lifecycle notifications emitted by the context server manager.
///
/// The derives make events loggable (`Debug`), cheap to forward (`Clone`, the payload is an
/// `Arc<str>`), and directly comparable in assertions (`PartialEq`/`Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// Emitted after the server with this id was started successfully.
    ServerStarted { server_id: Arc<str> },
    /// Emitted after the server with this id was stopped.
    ServerStopped { server_id: Arc<str> },
}
119
120impl EventEmitter<Event> for ContextServerManager {}
121
122impl ContextServerManager {
123 pub fn new(
124 registry: Entity<ContextServerFactoryRegistry>,
125 project: Entity<Project>,
126 cx: &mut Context<Self>,
127 ) -> Self {
128 let mut this = Self {
129 _subscriptions: vec![
130 cx.observe(®istry, |this, _registry, cx| {
131 this.available_context_servers_changed(cx);
132 }),
133 cx.observe_global::<SettingsStore>(|this, cx| {
134 this.available_context_servers_changed(cx);
135 }),
136 ],
137 project,
138 registry,
139 needs_server_update: false,
140 servers: HashMap::default(),
141 update_servers_task: None,
142 };
143 this.available_context_servers_changed(cx);
144 this
145 }
146
147 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
148 if self.update_servers_task.is_some() {
149 self.needs_server_update = true;
150 } else {
151 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
152 this.update(cx, |this, _| {
153 this.needs_server_update = false;
154 })?;
155
156 Self::maintain_servers(this.clone(), cx).await?;
157
158 this.update(cx, |this, cx| {
159 let has_any_context_servers = !this.running_servers().is_empty();
160 if has_any_context_servers {
161 CommandPaletteFilter::update_global(cx, |filter, _cx| {
162 filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
163 });
164 }
165
166 this.update_servers_task.take();
167 if this.needs_server_update {
168 this.available_context_servers_changed(cx);
169 }
170 })?;
171
172 Ok(())
173 }));
174 }
175 }
176
177 pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
178 self.servers
179 .get(id)
180 .filter(|server| server.client().is_some())
181 .cloned()
182 }
183
184 pub fn start_server(
185 &self,
186 server: Arc<ContextServer>,
187 cx: &mut Context<Self>,
188 ) -> Task<anyhow::Result<()>> {
189 cx.spawn(async move |this, cx| {
190 let id = server.id.clone();
191 server.start(&cx).await?;
192 this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
193 Ok(())
194 })
195 }
196
197 pub fn stop_server(
198 &self,
199 server: Arc<ContextServer>,
200 cx: &mut Context<Self>,
201 ) -> anyhow::Result<()> {
202 server.stop()?;
203 cx.emit(Event::ServerStopped {
204 server_id: server.id(),
205 });
206 Ok(())
207 }
208
209 pub fn restart_server(
210 &mut self,
211 id: &Arc<str>,
212 cx: &mut Context<Self>,
213 ) -> Task<anyhow::Result<()>> {
214 let id = id.clone();
215 cx.spawn(async move |this, cx| {
216 if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
217 server.stop()?;
218 let config = server.config();
219 let new_server = Arc::new(ContextServer::new(id.clone(), config));
220 new_server.clone().start(&cx).await?;
221 this.update(cx, |this, cx| {
222 this.servers.insert(id.clone(), new_server);
223 cx.emit(Event::ServerStopped {
224 server_id: id.clone(),
225 });
226 cx.emit(Event::ServerStarted {
227 server_id: id.clone(),
228 });
229 })?;
230 }
231 Ok(())
232 })
233 }
234
235 pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
236 self.servers.values().cloned().collect()
237 }
238
239 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
240 self.servers
241 .values()
242 .filter(|server| server.client().is_some())
243 .cloned()
244 .collect()
245 }
246
247 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
248 let mut desired_servers = HashMap::default();
249
250 let (registry, project) = this.update(cx, |this, cx| {
251 let location = this
252 .project
253 .read(cx)
254 .visible_worktrees(cx)
255 .next()
256 .map(|worktree| settings::SettingsLocation {
257 worktree_id: worktree.read(cx).id(),
258 path: Path::new(""),
259 });
260 let settings = ContextServerSettings::get(location, cx);
261 desired_servers = settings.context_servers.clone();
262
263 (this.registry.clone(), this.project.clone())
264 })?;
265
266 for (id, factory) in
267 registry.read_with(cx, |registry, _| registry.context_server_factories())?
268 {
269 let config = desired_servers.entry(id).or_default();
270 if config.command.is_none() {
271 if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
272 config.command = Some(extension_command);
273 }
274 }
275 }
276
277 let mut servers_to_start = HashMap::default();
278 let mut servers_to_stop = HashMap::default();
279
280 this.update(cx, |this, _cx| {
281 this.servers.retain(|id, server| {
282 if desired_servers.contains_key(id) {
283 true
284 } else {
285 servers_to_stop.insert(id.clone(), server.clone());
286 false
287 }
288 });
289
290 for (id, config) in desired_servers {
291 let existing_config = this.servers.get(&id).map(|server| server.config());
292 if existing_config.as_deref() != Some(&config) {
293 let config = Arc::new(config);
294 let server = Arc::new(ContextServer::new(id.clone(), config));
295 servers_to_start.insert(id.clone(), server.clone());
296 let old_server = this.servers.insert(id.clone(), server);
297 if let Some(old_server) = old_server {
298 servers_to_stop.insert(id, old_server);
299 }
300 }
301 }
302 })?;
303
304 for (id, server) in servers_to_stop {
305 server.stop().log_err();
306 this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
307 }
308
309 for (id, server) in servers_to_start {
310 if server.start(&cx).await.log_err().is_some() {
311 this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
312 }
313 }
314
315 Ok(())
316 }
317}