//! This module implements context server management for Zed.
//!
//! It provides functionality to:
//! - Manage individual context servers (start, stop, restart)
//! - Keep the set of context servers in sync with `ContextServerSettings` and
//!   with the servers provided by the `ContextServerFactoryRegistry`
//! - Notify observers when a context server starts or stops
//!
//! Key components:
//! - `ContextServer`: Represents an individual context server
//! - `ContextServerManager`: Manages multiple context servers for a project
//! - `Event`: Emitted by `ContextServerManager` when a server starts or stops
//!
//! `ContextServerSettings` and `ServerConfig` are defined at the crate root. The
//! manager re-syncs its servers whenever the settings or the factory registry change.
16
17use std::path::Path;
18use std::sync::Arc;
19
20use anyhow::{bail, Result};
21use collections::HashMap;
22use command_palette_hooks::CommandPaletteFilter;
23use gpui::{AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
24use log;
25use parking_lot::RwLock;
26use project::Project;
27use settings::{Settings, SettingsStore};
28use util::ResultExt as _;
29
30use crate::{ContextServerSettings, ServerConfig};
31
32use crate::{
33 client::{self, Client},
34 types, ContextServerFactoryRegistry, CONTEXT_SERVERS_NAMESPACE,
35};
36
37pub struct ContextServer {
38 pub id: Arc<str>,
39 pub config: Arc<ServerConfig>,
40 pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
41}
42
43impl ContextServer {
44 pub fn new(id: Arc<str>, config: Arc<ServerConfig>) -> Self {
45 Self {
46 id,
47 config,
48 client: RwLock::new(None),
49 }
50 }
51
52 pub fn id(&self) -> Arc<str> {
53 self.id.clone()
54 }
55
56 pub fn config(&self) -> Arc<ServerConfig> {
57 self.config.clone()
58 }
59
60 pub fn client(&self) -> Option<Arc<crate::protocol::InitializedContextServerProtocol>> {
61 self.client.read().clone()
62 }
63
64 pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
65 log::info!("starting context server {}", self.id);
66 let Some(command) = &self.config.command else {
67 bail!("no command specified for server {}", self.id);
68 };
69 let client = Client::new(
70 client::ContextServerId(self.id.clone()),
71 client::ModelContextServerBinary {
72 executable: Path::new(&command.path).to_path_buf(),
73 args: command.args.clone(),
74 env: command.env.clone(),
75 },
76 cx.clone(),
77 )?;
78
79 let protocol = crate::protocol::ModelContextProtocol::new(client);
80 let client_info = types::Implementation {
81 name: "Zed".to_string(),
82 version: env!("CARGO_PKG_VERSION").to_string(),
83 };
84 let initialized_protocol = protocol.initialize(client_info).await?;
85
86 log::debug!(
87 "context server {} initialized: {:?}",
88 self.id,
89 initialized_protocol.initialize,
90 );
91
92 *self.client.write() = Some(Arc::new(initialized_protocol));
93 Ok(())
94 }
95
96 pub fn stop(&self) -> Result<()> {
97 let mut client = self.client.write();
98 if let Some(protocol) = client.take() {
99 drop(protocol);
100 }
101 Ok(())
102 }
103}
104
105pub struct ContextServerManager {
106 servers: HashMap<Arc<str>, Arc<ContextServer>>,
107 project: Entity<Project>,
108 registry: Entity<ContextServerFactoryRegistry>,
109 update_servers_task: Option<Task<Result<()>>>,
110 needs_server_update: bool,
111 _subscriptions: Vec<Subscription>,
112}
113
114pub enum Event {
115 ServerStarted { server_id: Arc<str> },
116 ServerStopped { server_id: Arc<str> },
117}
118
119impl EventEmitter<Event> for ContextServerManager {}
120
121impl ContextServerManager {
122 pub fn new(
123 registry: Entity<ContextServerFactoryRegistry>,
124 project: Entity<Project>,
125 cx: &mut Context<Self>,
126 ) -> Self {
127 let mut this = Self {
128 _subscriptions: vec![
129 cx.observe(®istry, |this, _registry, cx| {
130 this.available_context_servers_changed(cx);
131 }),
132 cx.observe_global::<SettingsStore>(|this, cx| {
133 this.available_context_servers_changed(cx);
134 }),
135 ],
136 project,
137 registry,
138 needs_server_update: false,
139 servers: HashMap::default(),
140 update_servers_task: None,
141 };
142 this.available_context_servers_changed(cx);
143 this
144 }
145
146 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
147 if self.update_servers_task.is_some() {
148 self.needs_server_update = true;
149 } else {
150 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
151 this.update(cx, |this, _| {
152 this.needs_server_update = false;
153 })?;
154
155 Self::maintain_servers(this.clone(), cx).await?;
156
157 this.update(cx, |this, cx| {
158 let has_any_context_servers = !this.running_servers().is_empty();
159 if has_any_context_servers {
160 CommandPaletteFilter::update_global(cx, |filter, _cx| {
161 filter.show_namespace(CONTEXT_SERVERS_NAMESPACE);
162 });
163 }
164
165 this.update_servers_task.take();
166 if this.needs_server_update {
167 this.available_context_servers_changed(cx);
168 }
169 })?;
170
171 Ok(())
172 }));
173 }
174 }
175
176 pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
177 self.servers
178 .get(id)
179 .filter(|server| server.client().is_some())
180 .cloned()
181 }
182
183 pub fn start_server(
184 &self,
185 server: Arc<ContextServer>,
186 cx: &mut Context<Self>,
187 ) -> Task<anyhow::Result<()>> {
188 cx.spawn(async move |this, cx| {
189 let id = server.id.clone();
190 server.start(&cx).await?;
191 this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
192 Ok(())
193 })
194 }
195
196 pub fn stop_server(
197 &self,
198 server: Arc<ContextServer>,
199 cx: &mut Context<Self>,
200 ) -> anyhow::Result<()> {
201 server.stop()?;
202 cx.emit(Event::ServerStopped {
203 server_id: server.id(),
204 });
205 Ok(())
206 }
207
208 pub fn restart_server(
209 &mut self,
210 id: &Arc<str>,
211 cx: &mut Context<Self>,
212 ) -> Task<anyhow::Result<()>> {
213 let id = id.clone();
214 cx.spawn(async move |this, cx| {
215 if let Some(server) = this.update(cx, |this, _cx| this.servers.remove(&id))? {
216 server.stop()?;
217 let config = server.config();
218 let new_server = Arc::new(ContextServer::new(id.clone(), config));
219 new_server.clone().start(&cx).await?;
220 this.update(cx, |this, cx| {
221 this.servers.insert(id.clone(), new_server);
222 cx.emit(Event::ServerStopped {
223 server_id: id.clone(),
224 });
225 cx.emit(Event::ServerStarted {
226 server_id: id.clone(),
227 });
228 })?;
229 }
230 Ok(())
231 })
232 }
233
234 pub fn all_servers(&self) -> Vec<Arc<ContextServer>> {
235 self.servers.values().cloned().collect()
236 }
237
238 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
239 self.servers
240 .values()
241 .filter(|server| server.client().is_some())
242 .cloned()
243 .collect()
244 }
245
246 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
247 let mut desired_servers = HashMap::default();
248
249 let (registry, project) = this.update(cx, |this, cx| {
250 let location = this
251 .project
252 .read(cx)
253 .visible_worktrees(cx)
254 .next()
255 .map(|worktree| settings::SettingsLocation {
256 worktree_id: worktree.read(cx).id(),
257 path: Path::new(""),
258 });
259 let settings = ContextServerSettings::get(location, cx);
260 desired_servers = settings.context_servers.clone();
261
262 (this.registry.clone(), this.project.clone())
263 })?;
264
265 for (id, factory) in
266 registry.read_with(cx, |registry, _| registry.context_server_factories())?
267 {
268 let config = desired_servers.entry(id).or_default();
269 if config.command.is_none() {
270 if let Some(extension_command) = factory(project.clone(), &cx).await.log_err() {
271 config.command = Some(extension_command);
272 }
273 }
274 }
275
276 let mut servers_to_start = HashMap::default();
277 let mut servers_to_stop = HashMap::default();
278
279 this.update(cx, |this, _cx| {
280 this.servers.retain(|id, server| {
281 if desired_servers.contains_key(id) {
282 true
283 } else {
284 servers_to_stop.insert(id.clone(), server.clone());
285 false
286 }
287 });
288
289 for (id, config) in desired_servers {
290 let existing_config = this.servers.get(&id).map(|server| server.config());
291 if existing_config.as_deref() != Some(&config) {
292 let config = Arc::new(config);
293 let server = Arc::new(ContextServer::new(id.clone(), config));
294 servers_to_start.insert(id.clone(), server.clone());
295 let old_server = this.servers.insert(id.clone(), server);
296 if let Some(old_server) = old_server {
297 servers_to_stop.insert(id, old_server);
298 }
299 }
300 }
301 })?;
302
303 for (id, server) in servers_to_stop {
304 server.stop().log_err();
305 this.update(cx, |_, cx| cx.emit(Event::ServerStopped { server_id: id }))?;
306 }
307
308 for (id, server) in servers_to_start {
309 if server.start(&cx).await.log_err().is_some() {
310 this.update(cx, |_, cx| cx.emit(Event::ServerStarted { server_id: id }))?;
311 }
312 }
313
314 Ok(())
315 }
316}