1//! This module implements a context server management system for Zed.
2//!
3//! It provides functionality to:
4//! - Define and load context server settings
5//! - Manage individual context servers (start, stop, restart)
6//! - Maintain a global manager for all context servers
7//!
8//! Key components:
9//! - `ContextServerSettings`: Defines the structure for server configurations
10//! - `ContextServer`: Represents an individual context server
11//! - `ContextServerManager`: Manages multiple context servers
12//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
13//!
14//! The module also includes initialization logic to set up the context server system
15//! and react to changes in settings.
16
17use std::path::Path;
18use std::sync::Arc;
19
20use collections::{HashMap, HashSet};
21use gpui::{AsyncAppContext, EventEmitter, ModelContext, Task};
22use log;
23use parking_lot::RwLock;
24use schemars::JsonSchema;
25use serde::{Deserialize, Serialize};
26use settings::{Settings, SettingsSources};
27
28use crate::{
29 client::{self, Client},
30 types,
31};
32
33#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
34pub struct ContextServerSettings {
35 pub servers: Vec<ServerConfig>,
36}
37
38#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
39pub struct ServerConfig {
40 pub id: String,
41 pub executable: String,
42 pub args: Vec<String>,
43 pub env: Option<HashMap<String, String>>,
44}
45
46impl Settings for ContextServerSettings {
47 const KEY: Option<&'static str> = Some("experimental.context_servers");
48
49 type FileContent = Self;
50
51 fn load(
52 sources: SettingsSources<Self::FileContent>,
53 _: &mut gpui::AppContext,
54 ) -> anyhow::Result<Self> {
55 sources.json_merge()
56 }
57}
58
59pub struct ContextServer {
60 pub id: String,
61 pub config: ServerConfig,
62 pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
63}
64
65impl ContextServer {
66 fn new(config: ServerConfig) -> Self {
67 Self {
68 id: config.id.clone(),
69 config,
70 client: RwLock::new(None),
71 }
72 }
73
74 async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
75 log::info!("starting context server {}", self.config.id,);
76 let client = Client::new(
77 client::ContextServerId(self.config.id.clone()),
78 client::ModelContextServerBinary {
79 executable: Path::new(&self.config.executable).to_path_buf(),
80 args: self.config.args.clone(),
81 env: self.config.env.clone(),
82 },
83 cx.clone(),
84 )?;
85
86 let protocol = crate::protocol::ModelContextProtocol::new(client);
87 let client_info = types::Implementation {
88 name: "Zed".to_string(),
89 version: env!("CARGO_PKG_VERSION").to_string(),
90 };
91 let initialized_protocol = protocol.initialize(client_info).await?;
92
93 log::debug!(
94 "context server {} initialized: {:?}",
95 self.config.id,
96 initialized_protocol.initialize,
97 );
98
99 *self.client.write() = Some(Arc::new(initialized_protocol));
100 Ok(())
101 }
102
103 async fn stop(&self) -> anyhow::Result<()> {
104 let mut client = self.client.write();
105 if let Some(protocol) = client.take() {
106 drop(protocol);
107 }
108 Ok(())
109 }
110}
111
112/// A Context server manager manages the starting and stopping
113/// of all servers. To obtain a server to interact with, a crate
114/// must go through the `GlobalContextServerManager` which holds
115/// a model to the ContextServerManager.
116pub struct ContextServerManager {
117 servers: HashMap<String, Arc<ContextServer>>,
118 pending_servers: HashSet<String>,
119}
120
/// Lifecycle notifications emitted by `ContextServerManager`.
///
/// Derives `Debug`/`Clone`/`PartialEq` so subscribers can log, store, and
/// compare events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// The server with `server_id` finished starting and is now available.
    ServerStarted { server_id: String },
    /// The server with `server_id` was stopped or removed.
    ServerStopped { server_id: String },
}
125
126impl EventEmitter<Event> for ContextServerManager {}
127
128impl Default for ContextServerManager {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl ContextServerManager {
135 pub fn new() -> Self {
136 Self {
137 servers: HashMap::default(),
138 pending_servers: HashSet::default(),
139 }
140 }
141
142 pub fn add_server(
143 &mut self,
144 config: ServerConfig,
145 cx: &ModelContext<Self>,
146 ) -> Task<anyhow::Result<()>> {
147 let server_id = config.id.clone();
148
149 if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
150 return Task::ready(Ok(()));
151 }
152
153 let task = {
154 let server_id = server_id.clone();
155 cx.spawn(|this, mut cx| async move {
156 let server = Arc::new(ContextServer::new(config));
157 server.start(&cx).await?;
158 this.update(&mut cx, |this, cx| {
159 this.servers.insert(server_id.clone(), server);
160 this.pending_servers.remove(&server_id);
161 cx.emit(Event::ServerStarted {
162 server_id: server_id.clone(),
163 });
164 })?;
165 Ok(())
166 })
167 };
168
169 self.pending_servers.insert(server_id);
170 task
171 }
172
173 pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
174 self.servers.get(id).cloned()
175 }
176
177 pub fn remove_server(&mut self, id: &str, cx: &ModelContext<Self>) -> Task<anyhow::Result<()>> {
178 let id = id.to_string();
179 cx.spawn(|this, mut cx| async move {
180 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
181 server.stop().await?;
182 }
183 this.update(&mut cx, |this, cx| {
184 this.pending_servers.remove(&id);
185 cx.emit(Event::ServerStopped {
186 server_id: id.clone(),
187 })
188 })?;
189 Ok(())
190 })
191 }
192
193 pub fn restart_server(
194 &mut self,
195 id: &str,
196 cx: &mut ModelContext<Self>,
197 ) -> Task<anyhow::Result<()>> {
198 let id = id.to_string();
199 cx.spawn(|this, mut cx| async move {
200 if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
201 server.stop().await?;
202 let config = server.config.clone();
203 let new_server = Arc::new(ContextServer::new(config));
204 new_server.start(&cx).await?;
205 this.update(&mut cx, |this, cx| {
206 this.servers.insert(id.clone(), new_server);
207 cx.emit(Event::ServerStopped {
208 server_id: id.clone(),
209 });
210 cx.emit(Event::ServerStarted {
211 server_id: id.clone(),
212 });
213 })?;
214 }
215 Ok(())
216 })
217 }
218
219 pub fn servers(&self) -> Vec<Arc<ContextServer>> {
220 self.servers.values().cloned().collect()
221 }
222
223 pub fn maintain_servers(&mut self, settings: &ContextServerSettings, cx: &ModelContext<Self>) {
224 let current_servers = self
225 .servers()
226 .into_iter()
227 .map(|server| (server.id.clone(), server.config.clone()))
228 .collect::<HashMap<_, _>>();
229
230 let new_servers = settings
231 .servers
232 .iter()
233 .map(|config| (config.id.clone(), config.clone()))
234 .collect::<HashMap<_, _>>();
235
236 let servers_to_add = new_servers
237 .values()
238 .filter(|config| !current_servers.contains_key(&config.id))
239 .cloned()
240 .collect::<Vec<_>>();
241
242 let servers_to_remove = current_servers
243 .keys()
244 .filter(|id| !new_servers.contains_key(*id))
245 .cloned()
246 .collect::<Vec<_>>();
247
248 log::trace!("servers_to_add={:?}", servers_to_add);
249 for config in servers_to_add {
250 self.add_server(config, cx).detach_and_log_err(cx);
251 }
252
253 for id in servers_to_remove {
254 self.remove_server(&id, cx).detach_and_log_err(cx);
255 }
256 }
257}