1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::ResultExt as _;
14
15use crate::{
16 project_settings::{ContextServerSettings, ProjectSettings},
17 worktree_store::WorktreeStore,
18};
19
20pub fn init(cx: &mut App) {
21 extension::init(cx);
22}
23
// Declares the `Restart` action in the `context_server` namespace.
// The `///` line inside the macro is macro input (it may be surfaced as the action's
// documentation), so its text is intentionally left unchanged.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
31
/// Publicly visible lifecycle status of a context server.
///
/// This is what `ContextServerStore::status_for_server` returns and what
/// `Event::ServerStatusChanged` carries.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ContextServerStatus {
    /// The server is in the process of starting.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped or removed from the store.
    Stopped,
    /// The server failed to start; the payload is the error message.
    Error(Arc<str>),
}
39
40impl ContextServerStatus {
41 fn from_state(state: &ContextServerState) -> Self {
42 match state {
43 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
44 ContextServerState::Running { .. } => ContextServerStatus::Running,
45 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
46 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
47 }
48 }
49}
50
/// Internal lifecycle state of one context server tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was launched with, so the
/// store can report the configuration and restart the server regardless of its state.
enum ContextServerState {
    /// `run_server` has spawned a task that is starting the server. When that task
    /// finishes, it replaces this state with `Running` or `Error`.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Held only so the startup task lives as long as this state; replacing or removing
        // the state drops it. NOTE(review): relies on dropping a gpui `Task` cancelling it.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped through `stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Starting the server failed.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Stringified startup error, surfaced through `ContextServerStatus::Error`.
        error: Arc<str>,
    },
}
71
72impl ContextServerState {
73 pub fn server(&self) -> Arc<ContextServer> {
74 match self {
75 ContextServerState::Starting { server, .. } => server.clone(),
76 ContextServerState::Running { server, .. } => server.clone(),
77 ContextServerState::Stopped { server, .. } => server.clone(),
78 ContextServerState::Error { server, .. } => server.clone(),
79 }
80 }
81
82 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
83 match self {
84 ContextServerState::Starting { configuration, .. } => configuration.clone(),
85 ContextServerState::Running { configuration, .. } => configuration.clone(),
86 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
87 ContextServerState::Error { configuration, .. } => configuration.clone(),
88 }
89 }
90}
91
92#[derive(Debug, PartialEq, Eq)]
93pub enum ContextServerConfiguration {
94 Custom {
95 command: ContextServerCommand,
96 },
97 Extension {
98 command: ContextServerCommand,
99 settings: serde_json::Value,
100 },
101}
102
103impl ContextServerConfiguration {
104 pub fn command(&self) -> &ContextServerCommand {
105 match self {
106 ContextServerConfiguration::Custom { command } => command,
107 ContextServerConfiguration::Extension { command, .. } => command,
108 }
109 }
110
111 pub async fn from_settings(
112 settings: ContextServerSettings,
113 id: ContextServerId,
114 registry: Entity<ContextServerDescriptorRegistry>,
115 worktree_store: Entity<WorktreeStore>,
116 cx: &AsyncApp,
117 ) -> Option<Self> {
118 match settings {
119 ContextServerSettings::Custom {
120 enabled: _,
121 command,
122 } => Some(ContextServerConfiguration::Custom { command }),
123 ContextServerSettings::Extension {
124 enabled: _,
125 settings,
126 } => {
127 let descriptor = cx
128 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
129 .ok()
130 .flatten()?;
131
132 let command = descriptor.command(worktree_store, cx).await.log_err()?;
133
134 Some(ContextServerConfiguration::Extension { command, settings })
135 }
136 }
137 }
138}
139
140pub type ContextServerFactory =
141 Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
142
/// Tracks the context servers configured for a project and drives their lifecycle.
pub struct ContextServerStore {
    /// Context server settings last resolved from the project settings, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current lifecycle state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Used to pick the settings location and to resolve extension commands.
    worktree_store: Entity<WorktreeStore>,
    /// Source of descriptors for extension-provided servers.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// In-flight reconciliation task spawned by `available_context_servers_changed`, if any.
    update_servers_task: Option<Task<Result<()>>>,
    /// Overrides server construction (set via `test_maintain_server_loop`); without it,
    /// servers are built with `ContextServer::stdio`.
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a reconciliation is requested while one is already running, so that another
    /// pass runs once the current one finishes.
    needs_server_update: bool,
    // Keeps the registry/settings observers alive for the lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
153
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// The server identified by `server_id` transitioned to `status`.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

// Lets entities subscribe to the store's `Event`s.
impl EventEmitter<Event> for ContextServerStore {}
162
163impl ContextServerStore {
164 pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
165 Self::new_internal(
166 true,
167 None,
168 ContextServerDescriptorRegistry::default_global(cx),
169 worktree_store,
170 cx,
171 )
172 }
173
174 /// Returns all configured context server ids, regardless of enabled state.
175 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
176 self.context_server_settings
177 .keys()
178 .cloned()
179 .map(ContextServerId)
180 .collect()
181 }
182
183 #[cfg(any(test, feature = "test-support"))]
184 pub fn test(
185 registry: Entity<ContextServerDescriptorRegistry>,
186 worktree_store: Entity<WorktreeStore>,
187 cx: &mut Context<Self>,
188 ) -> Self {
189 Self::new_internal(false, None, registry, worktree_store, cx)
190 }
191
192 #[cfg(any(test, feature = "test-support"))]
193 pub fn test_maintain_server_loop(
194 context_server_factory: ContextServerFactory,
195 registry: Entity<ContextServerDescriptorRegistry>,
196 worktree_store: Entity<WorktreeStore>,
197 cx: &mut Context<Self>,
198 ) -> Self {
199 Self::new_internal(
200 true,
201 Some(context_server_factory),
202 registry,
203 worktree_store,
204 cx,
205 )
206 }
207
208 fn new_internal(
209 maintain_server_loop: bool,
210 context_server_factory: Option<ContextServerFactory>,
211 registry: Entity<ContextServerDescriptorRegistry>,
212 worktree_store: Entity<WorktreeStore>,
213 cx: &mut Context<Self>,
214 ) -> Self {
215 let subscriptions = if maintain_server_loop {
216 vec![
217 cx.observe(®istry, |this, _registry, cx| {
218 this.available_context_servers_changed(cx);
219 }),
220 cx.observe_global::<SettingsStore>(|this, cx| {
221 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
222 if &this.context_server_settings == settings {
223 return;
224 }
225 this.context_server_settings = settings.clone();
226 this.available_context_servers_changed(cx);
227 }),
228 ]
229 } else {
230 Vec::new()
231 };
232
233 let mut this = Self {
234 _subscriptions: subscriptions,
235 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
236 .clone(),
237 worktree_store,
238 registry,
239 needs_server_update: false,
240 servers: HashMap::default(),
241 update_servers_task: None,
242 context_server_factory,
243 };
244 if maintain_server_loop {
245 this.available_context_servers_changed(cx);
246 }
247 this
248 }
249
250 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
251 self.servers.get(id).map(|state| state.server())
252 }
253
254 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
255 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
256 Some(server.clone())
257 } else {
258 None
259 }
260 }
261
262 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
263 self.servers.get(id).map(ContextServerStatus::from_state)
264 }
265
266 pub fn configuration_for_server(
267 &self,
268 id: &ContextServerId,
269 ) -> Option<Arc<ContextServerConfiguration>> {
270 self.servers.get(id).map(|state| state.configuration())
271 }
272
273 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
274 self.servers.keys().cloned().collect()
275 }
276
277 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
278 self.servers
279 .values()
280 .filter_map(|state| {
281 if let ContextServerState::Running { server, .. } = state {
282 Some(server.clone())
283 } else {
284 None
285 }
286 })
287 .collect()
288 }
289
290 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
291 cx.spawn(async move |this, cx| {
292 let this = this.upgrade().context("Context server store dropped")?;
293 let settings = this
294 .update(cx, |this, _| {
295 this.context_server_settings.get(&server.id().0).cloned()
296 })
297 .ok()
298 .flatten()
299 .context("Failed to get context server settings")?;
300
301 if !settings.enabled() {
302 return Ok(());
303 }
304
305 let (registry, worktree_store) = this.update(cx, |this, _| {
306 (this.registry.clone(), this.worktree_store.clone())
307 })?;
308 let configuration = ContextServerConfiguration::from_settings(
309 settings,
310 server.id(),
311 registry,
312 worktree_store,
313 cx,
314 )
315 .await
316 .context("Failed to create context server configuration")?;
317
318 this.update(cx, |this, cx| {
319 this.run_server(server, Arc::new(configuration), cx)
320 })
321 })
322 .detach_and_log_err(cx);
323 }
324
325 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
326 if matches!(
327 self.servers.get(id),
328 Some(ContextServerState::Stopped { .. })
329 ) {
330 return Ok(());
331 }
332
333 let state = self
334 .servers
335 .remove(id)
336 .context("Context server not found")?;
337
338 let server = state.server();
339 let configuration = state.configuration();
340 let mut result = Ok(());
341 if let ContextServerState::Running { server, .. } = &state {
342 result = server.stop();
343 }
344 drop(state);
345
346 self.update_server_state(
347 id.clone(),
348 ContextServerState::Stopped {
349 configuration,
350 server,
351 },
352 cx,
353 );
354
355 result
356 }
357
358 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
359 if let Some(state) = self.servers.get(&id) {
360 let configuration = state.configuration();
361
362 self.stop_server(&state.server().id(), cx)?;
363 let new_server = self.create_context_server(id.clone(), configuration.clone())?;
364 self.run_server(new_server, configuration, cx);
365 }
366 Ok(())
367 }
368
369 fn run_server(
370 &mut self,
371 server: Arc<ContextServer>,
372 configuration: Arc<ContextServerConfiguration>,
373 cx: &mut Context<Self>,
374 ) {
375 let id = server.id();
376 if matches!(
377 self.servers.get(&id),
378 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
379 ) {
380 self.stop_server(&id, cx).log_err();
381 }
382
383 let task = cx.spawn({
384 let id = server.id();
385 let server = server.clone();
386 let configuration = configuration.clone();
387 async move |this, cx| {
388 match server.clone().start(&cx).await {
389 Ok(_) => {
390 log::info!("Started {} context server", id);
391 debug_assert!(server.client().is_some());
392
393 this.update(cx, |this, cx| {
394 this.update_server_state(
395 id.clone(),
396 ContextServerState::Running {
397 server,
398 configuration,
399 },
400 cx,
401 )
402 })
403 .log_err()
404 }
405 Err(err) => {
406 log::error!("{} context server failed to start: {}", id, err);
407 this.update(cx, |this, cx| {
408 this.update_server_state(
409 id.clone(),
410 ContextServerState::Error {
411 configuration,
412 server,
413 error: err.to_string().into(),
414 },
415 cx,
416 )
417 })
418 .log_err()
419 }
420 };
421 }
422 });
423
424 self.update_server_state(
425 id.clone(),
426 ContextServerState::Starting {
427 configuration,
428 _task: task,
429 server,
430 },
431 cx,
432 );
433 }
434
435 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
436 let state = self
437 .servers
438 .remove(id)
439 .context("Context server not found")?;
440 drop(state);
441 cx.emit(Event::ServerStatusChanged {
442 server_id: id.clone(),
443 status: ContextServerStatus::Stopped,
444 });
445 Ok(())
446 }
447
448 fn create_context_server(
449 &self,
450 id: ContextServerId,
451 configuration: Arc<ContextServerConfiguration>,
452 ) -> Result<Arc<ContextServer>> {
453 if let Some(factory) = self.context_server_factory.as_ref() {
454 Ok(factory(id, configuration))
455 } else {
456 Ok(Arc::new(ContextServer::stdio(
457 id,
458 configuration.command().clone(),
459 )))
460 }
461 }
462
463 fn resolve_context_server_settings<'a>(
464 worktree_store: &'a Entity<WorktreeStore>,
465 cx: &'a App,
466 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
467 let location = worktree_store
468 .read(cx)
469 .visible_worktrees(cx)
470 .next()
471 .map(|worktree| settings::SettingsLocation {
472 worktree_id: worktree.read(cx).id(),
473 path: Path::new(""),
474 });
475 &ProjectSettings::get(location, cx).context_servers
476 }
477
478 fn update_server_state(
479 &mut self,
480 id: ContextServerId,
481 state: ContextServerState,
482 cx: &mut Context<Self>,
483 ) {
484 let status = ContextServerStatus::from_state(&state);
485 self.servers.insert(id.clone(), state);
486 cx.emit(Event::ServerStatusChanged {
487 server_id: id,
488 status,
489 });
490 }
491
492 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
493 if self.update_servers_task.is_some() {
494 self.needs_server_update = true;
495 } else {
496 self.needs_server_update = false;
497 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
498 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
499 log::error!("Error maintaining context servers: {}", err);
500 }
501
502 this.update(cx, |this, cx| {
503 this.update_servers_task.take();
504 if this.needs_server_update {
505 this.available_context_servers_changed(cx);
506 }
507 })?;
508
509 Ok(())
510 }));
511 }
512 }
513
514 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
515 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
516 (
517 this.context_server_settings.clone(),
518 this.registry.clone(),
519 this.worktree_store.clone(),
520 )
521 })?;
522
523 for (id, _) in
524 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
525 {
526 configured_servers
527 .entry(id)
528 .or_insert(ContextServerSettings::default_extension());
529 }
530
531 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
532 configured_servers
533 .into_iter()
534 .partition(|(_, settings)| settings.enabled());
535
536 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
537 let id = ContextServerId(id);
538 ContextServerConfiguration::from_settings(
539 settings,
540 id.clone(),
541 registry.clone(),
542 worktree_store.clone(),
543 cx,
544 )
545 .map(|config| (id, config))
546 }))
547 .await
548 .into_iter()
549 .filter_map(|(id, config)| config.map(|config| (id, config)))
550 .collect::<HashMap<_, _>>();
551
552 let mut servers_to_start = Vec::new();
553 let mut servers_to_remove = HashSet::default();
554 let mut servers_to_stop = HashSet::default();
555
556 this.update(cx, |this, _cx| {
557 for server_id in this.servers.keys() {
558 // All servers that are not in desired_servers should be removed from the store.
559 // This can happen if the user removed a server from the context server settings.
560 if !configured_servers.contains_key(&server_id) {
561 if disabled_servers.contains_key(&server_id.0) {
562 servers_to_stop.insert(server_id.clone());
563 } else {
564 servers_to_remove.insert(server_id.clone());
565 }
566 }
567 }
568
569 for (id, config) in configured_servers {
570 let state = this.servers.get(&id);
571 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
572 let existing_config = state.as_ref().map(|state| state.configuration());
573 if existing_config.as_deref() != Some(&config) || is_stopped {
574 let config = Arc::new(config);
575 if let Some(server) = this
576 .create_context_server(id.clone(), config.clone())
577 .log_err()
578 {
579 servers_to_start.push((server, config));
580 if this.servers.contains_key(&id) {
581 servers_to_stop.insert(id);
582 }
583 }
584 }
585 }
586 })?;
587
588 this.update(cx, |this, cx| {
589 for id in servers_to_stop {
590 this.stop_server(&id, cx)?;
591 }
592 for id in servers_to_remove {
593 this.remove_server(&id, cx)?;
594 }
595 for (server, config) in servers_to_start {
596 this.run_server(server, config, cx);
597 }
598 anyhow::Ok(())
599 })?
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
        project_settings::ProjectSettings,
    };
    use context_server::test::create_fake_transport;
    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
    use serde_json::json;
    use std::{cell::RefCell, path::PathBuf, rc::Rc};
    use util::path;

    const SERVER_1_ID: &str = "mcp-1";
    const SERVER_2_ID: &str = "mcp-2";

    /// Manually starting and stopping servers is reflected in `status_for_server`.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());
        let server_1 = fake_server(&server_1_id, cx);
        let server_2 = fake_server(&server_2_id, cx);

        store.update(cx, |store, cx| store.start_server(server_1, cx));
        cx.run_until_parked();
        assert_statuses(
            &store,
            &[
                (&server_1_id, Some(ContextServerStatus::Running)),
                (&server_2_id, None),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
        cx.run_until_parked();
        assert_statuses(
            &store,
            &[
                (&server_1_id, Some(ContextServerStatus::Running)),
                (&server_2_id, Some(ContextServerStatus::Running)),
            ],
            cx,
        );

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
        assert_statuses(
            &store,
            &[
                (&server_1_id, Some(ContextServerStatus::Running)),
                (&server_2_id, Some(ContextServerStatus::Stopped)),
            ],
            cx,
        );
    }

    /// Manual starts/stops emit the expected sequence of status events.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());
        let server_1 = fake_server(&server_1_id, cx);
        let server_2 = fake_server(&server_2_id, cx);

        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));
        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }

    /// Starting a second server with the same id stops the first one.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_id = ContextServerId(SERVER_1_ID.into());
        let server_with_same_id_1 = fake_server(&server_id, cx);
        let server_with_same_id_2 = fake_server(&server_id, cx);

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        assert_statuses(
            &store,
            &[(&server_id, Some(ContextServerStatus::Running))],
            cx,
        );
    }

    /// The maintenance loop starts, restarts, and removes servers as settings change.
    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), extension_settings(true))],
        )
        .await;

        let factory = fake_server_factory(cx);
        let registry = cx.new(|cx| {
            let mut registry = ContextServerDescriptorRegistry::new();
            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
            registry
        });
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                factory,
                registry.clone(),
                project.read(cx).worktree_store(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Stopped),
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (server_1_id.0.clone(), extension_settings(false)),
                    (server_2_id.0.clone(), custom_settings(true, "arg")),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-2 is restarted once the args have changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Stopped),
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (server_1_id.0.clone(), extension_settings(false)),
                    (server_2_id.0.clone(), custom_settings(true, "anotherArg")),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );
            cx.run_until_parked();

            assert_statuses(&store, &[(&server_2_id, None)], cx);
        }

        // Ensure that nothing happens if the settings do not change
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );
            cx.run_until_parked();

            assert_statuses(
                &store,
                &[
                    (&server_1_id, Some(ContextServerStatus::Running)),
                    (&server_2_id, None),
                ],
                cx,
            );
        }
    }

    /// Toggling `enabled` stops and restarts a server through the maintenance loop.
    #[gpui::test]
    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
        let server_1_id = ContextServerId(SERVER_1_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), custom_settings(true, "arg"))],
        )
        .await;

        let factory = fake_server_factory(cx);
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                factory,
                registry.clone(),
                project.read(cx).worktree_store(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is stopped once it is disabled.
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), custom_settings(false, "arg"))],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is started once it is enabled again.
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), custom_settings(true, "arg"))],
                cx,
            );
            cx.run_until_parked();
        }
    }

    /// Replaces the user settings so that exactly `context_servers` are configured.
    fn set_context_server_configuration(
        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                let mut settings = ProjectSettings::default();
                for (id, config) in context_servers {
                    settings.context_servers.insert(id, config);
                }
                store
                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
                    .unwrap();
            })
        });
    }

    /// Guard returned by `assert_server_events`; on drop it checks that every expected
    /// event arrived.
    struct ServerEvents {
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }

    impl Drop for ServerEvents {
        fn drop(&mut self) {
            let actual_event_count = *self.received_event_count.borrow();
            assert_eq!(
                actual_event_count, self.expected_event_count,
                "\nExpected to receive {} context server store events, but received {} events",
                self.expected_event_count, actual_event_count
            );
        }
    }

    /// Enabled custom settings launching `somebinary arg`.
    fn dummy_server_settings() -> ContextServerSettings {
        custom_settings(true, "arg")
    }

    /// Custom settings launching `somebinary` with the single argument `arg`.
    fn custom_settings(enabled: bool, arg: &str) -> ContextServerSettings {
        ContextServerSettings::Custom {
            enabled,
            command: ContextServerCommand {
                path: "somebinary".into(),
                args: vec![arg.to_string()],
                env: None,
            },
        }
    }

    /// Enabled extension settings whose payload is `{"somevalue": <somevalue>}`.
    fn extension_settings(somevalue: bool) -> ContextServerSettings {
        ContextServerSettings::Extension {
            enabled: true,
            settings: json!({
                "somevalue": somevalue
            }),
        }
    }

    /// Builds a server for `id` backed by a fake transport named after the id.
    fn fake_server(id: &ContextServerId, cx: &TestAppContext) -> Arc<ContextServer> {
        Arc::new(ContextServer::new(
            id.clone(),
            Arc::new(create_fake_transport(id.0.to_string(), cx.executor())),
        ))
    }

    /// Factory that builds fake-transport servers, for `test_maintain_server_loop`.
    fn fake_server_factory(cx: &TestAppContext) -> ContextServerFactory {
        let executor = cx.executor();
        Box::new(move |id, _| {
            Arc::new(ContextServer::new(
                id.clone(),
                Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
            ))
        })
    }

    /// Asserts the store's current status for each `(id, expected status)` pair.
    fn assert_statuses(
        store: &Entity<ContextServerStore>,
        expected: &[(&ContextServerId, Option<ContextServerStatus>)],
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            for (id, status) in expected {
                assert_eq!(&store.read(cx).status_for_server(id), status);
            }
        });
    }

    /// Subscribes to `store` and asserts that its status events match `expected_events`, in
    /// order. The returned guard checks the total count when dropped.
    fn assert_server_events(
        store: &Entity<ContextServerStore>,
        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
        cx: &mut TestAppContext,
    ) -> ServerEvents {
        cx.update(|cx| {
            let received_event_count = Rc::new(RefCell::new(0));
            let expected_event_count = expected_events.len();
            let subscription = cx.subscribe(store, {
                let received_event_count = received_event_count.clone();
                move |_, event, _| {
                    let Event::ServerStatusChanged {
                        server_id: actual_server_id,
                        status: actual_status,
                    } = event;
                    // The number of events seen so far doubles as the index of the next one.
                    let ix = *received_event_count.borrow();
                    let (expected_server_id, expected_status) = &expected_events[ix];

                    assert_eq!(
                        actual_server_id, expected_server_id,
                        "Expected different server id at index {}",
                        ix
                    );
                    assert_eq!(
                        actual_status, expected_status,
                        "Expected different status at index {}",
                        ix
                    );
                    *received_event_count.borrow_mut() += 1;
                }
            });
            ServerEvents {
                expected_event_count,
                received_event_count,
                _subscription: subscription,
            }
        })
    }

    /// Installs test settings with the given context server configurations and creates a
    /// project over a fake file system rooted at `/test`.
    async fn setup_context_server_test(
        cx: &mut TestAppContext,
        files: serde_json::Value,
        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
    ) -> (Arc<FakeFs>, Entity<Project>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            let mut settings = ProjectSettings::get_global(cx).clone();
            for (id, config) in context_server_configurations {
                settings.context_servers.insert(id, config);
            }
            ProjectSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        (fs, project)
    }

    /// Descriptor that always resolves to `<path> arg1 arg2` and has no configuration.
    struct FakeContextServerDescriptor {
        path: PathBuf,
    }

    impl FakeContextServerDescriptor {
        fn new(path: impl Into<PathBuf>) -> Self {
            Self { path: path.into() }
        }
    }

    impl ContextServerDescriptor for FakeContextServerDescriptor {
        fn command(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<ContextServerCommand>> {
            Task::ready(Ok(ContextServerCommand {
                path: self.path.clone(),
                args: vec!["arg1".to_string(), "arg2".to_string()],
                env: None,
            }))
        }

        fn configuration(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
            Task::ready(Ok(None))
        }
    }
}