1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::ResultExt as _;
14
15use crate::{
16 project_settings::{ContextServerSettings, ProjectSettings},
17 worktree_store::WorktreeStore,
18};
19
/// Initializes this module by delegating to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
23
// Declares the `Restart` action in the `context_server` namespace.
actions!(context_server, [Restart]);
25
/// Publicly visible lifecycle status of a context server, derived from the
/// store's internal per-server state and reported through
/// [`Event::ServerStatusChanged`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's start task has been spawned but has not finished yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped (or removed from the store).
    Stopped,
    /// The server failed to start; carries the stringified start error.
    Error(Arc<str>),
}
33
34impl ContextServerStatus {
35 fn from_state(state: &ContextServerState) -> Self {
36 match state {
37 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
38 ContextServerState::Running { .. } => ContextServerStatus::Running,
39 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
40 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
41 }
42 }
43}
44
/// Internal per-server bookkeeping kept by [`ContextServerStore`].
///
/// Every variant carries the server handle and the configuration it was
/// created from, so both stay available regardless of lifecycle phase.
enum ContextServerState {
    /// The server's start task is in flight.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Task awaiting `server.start(..)`; it is owned by the state and dropped with it.
        // NOTE(review): presumably dropping the task cancels the pending start — confirm
        // against gpui's `Task` semantics.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped but is still tracked by the store.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server failed to start; `error` holds the stringified failure.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
65
66impl ContextServerState {
67 pub fn server(&self) -> Arc<ContextServer> {
68 match self {
69 ContextServerState::Starting { server, .. } => server.clone(),
70 ContextServerState::Running { server, .. } => server.clone(),
71 ContextServerState::Stopped { server, .. } => server.clone(),
72 ContextServerState::Error { server, .. } => server.clone(),
73 }
74 }
75
76 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
77 match self {
78 ContextServerState::Starting { configuration, .. } => configuration.clone(),
79 ContextServerState::Running { configuration, .. } => configuration.clone(),
80 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
81 ContextServerState::Error { configuration, .. } => configuration.clone(),
82 }
83 }
84}
85
/// Resolved configuration a context server is launched with.
///
/// Compared with `PartialEq` by the store to decide whether a server needs
/// to be restarted after a settings or registry change.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// Server whose command comes directly from the settings.
    Custom {
        command: ContextServerCommand,
    },
    /// Server whose command is resolved from a registered descriptor, paired
    /// with the settings value taken from the settings entry.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
96
97impl ContextServerConfiguration {
98 pub fn command(&self) -> &ContextServerCommand {
99 match self {
100 ContextServerConfiguration::Custom { command } => command,
101 ContextServerConfiguration::Extension { command, .. } => command,
102 }
103 }
104
105 pub async fn from_settings(
106 settings: ContextServerSettings,
107 id: ContextServerId,
108 registry: Entity<ContextServerDescriptorRegistry>,
109 worktree_store: Entity<WorktreeStore>,
110 cx: &AsyncApp,
111 ) -> Option<Self> {
112 match settings {
113 ContextServerSettings::Custom {
114 enabled: _,
115 command,
116 } => Some(ContextServerConfiguration::Custom { command }),
117 ContextServerSettings::Extension {
118 enabled: _,
119 settings,
120 } => {
121 let descriptor = cx
122 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
123 .ok()
124 .flatten()?;
125
126 let command = descriptor.command(worktree_store, cx).await.log_err()?;
127
128 Some(ContextServerConfiguration::Extension { command, settings })
129 }
130 }
131 }
132}
133
/// Factory that builds a [`ContextServer`] from an id and its configuration.
///
/// When a store holds one, it is used instead of creating a stdio-based
/// server from the configuration's command.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
136
/// Tracks context servers by id together with their lifecycle state, and
/// (optionally) keeps them in sync with settings and the descriptor registry.
pub struct ContextServerStore {
    // Per-server lifecycle state, keyed by server id.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to pick the settings location and handed to descriptors when resolving commands.
    worktree_store: Entity<WorktreeStore>,
    // Source of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // In-flight reconciliation task; `Some` while `maintain_servers` is running.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for server construction; when `None`, stdio servers are created.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a change arrives while a reconciliation is already running, so another
    // pass is scheduled once the current one finishes.
    needs_server_update: bool,
    // Keeps the registry/settings observers alive for the lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
146
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// Emitted whenever a server's state is inserted/updated in the store, and
    /// with [`ContextServerStatus::Stopped`] when a server is removed.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
155
156impl ContextServerStore {
157 pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
158 Self::new_internal(
159 true,
160 None,
161 ContextServerDescriptorRegistry::default_global(cx),
162 worktree_store,
163 cx,
164 )
165 }
166
167 #[cfg(any(test, feature = "test-support"))]
168 pub fn test(
169 registry: Entity<ContextServerDescriptorRegistry>,
170 worktree_store: Entity<WorktreeStore>,
171 cx: &mut Context<Self>,
172 ) -> Self {
173 Self::new_internal(false, None, registry, worktree_store, cx)
174 }
175
176 #[cfg(any(test, feature = "test-support"))]
177 pub fn test_maintain_server_loop(
178 context_server_factory: ContextServerFactory,
179 registry: Entity<ContextServerDescriptorRegistry>,
180 worktree_store: Entity<WorktreeStore>,
181 cx: &mut Context<Self>,
182 ) -> Self {
183 Self::new_internal(
184 true,
185 Some(context_server_factory),
186 registry,
187 worktree_store,
188 cx,
189 )
190 }
191
192 fn new_internal(
193 maintain_server_loop: bool,
194 context_server_factory: Option<ContextServerFactory>,
195 registry: Entity<ContextServerDescriptorRegistry>,
196 worktree_store: Entity<WorktreeStore>,
197 cx: &mut Context<Self>,
198 ) -> Self {
199 let subscriptions = if maintain_server_loop {
200 vec![
201 cx.observe(®istry, |this, _registry, cx| {
202 this.available_context_servers_changed(cx);
203 }),
204 cx.observe_global::<SettingsStore>(|this, cx| {
205 this.available_context_servers_changed(cx);
206 }),
207 ]
208 } else {
209 Vec::new()
210 };
211
212 let mut this = Self {
213 _subscriptions: subscriptions,
214 worktree_store,
215 registry,
216 needs_server_update: false,
217 servers: HashMap::default(),
218 update_servers_task: None,
219 context_server_factory,
220 };
221 if maintain_server_loop {
222 this.available_context_servers_changed(cx);
223 }
224 this
225 }
226
227 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
228 self.servers.get(id).map(|state| state.server())
229 }
230
231 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
232 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
233 Some(server.clone())
234 } else {
235 None
236 }
237 }
238
239 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
240 self.servers.get(id).map(ContextServerStatus::from_state)
241 }
242
243 pub fn configuration_for_server(
244 &self,
245 id: &ContextServerId,
246 ) -> Option<Arc<ContextServerConfiguration>> {
247 self.servers.get(id).map(|state| state.configuration())
248 }
249
250 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
251 self.servers.keys().cloned().collect()
252 }
253
254 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
255 self.servers
256 .values()
257 .filter_map(|state| {
258 if let ContextServerState::Running { server, .. } = state {
259 Some(server.clone())
260 } else {
261 None
262 }
263 })
264 .collect()
265 }
266
267 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
268 cx.spawn(async move |this, cx| {
269 let this = this.upgrade().context("Context server store dropped")?;
270 let settings = this
271 .update(cx, |this, cx| {
272 this.context_server_settings(cx)
273 .get(&server.id().0)
274 .cloned()
275 })
276 .ok()
277 .flatten()
278 .context("Failed to get context server settings")?;
279
280 if !settings.enabled() {
281 return Ok(());
282 }
283
284 let (registry, worktree_store) = this.update(cx, |this, _| {
285 (this.registry.clone(), this.worktree_store.clone())
286 })?;
287 let configuration = ContextServerConfiguration::from_settings(
288 settings,
289 server.id(),
290 registry,
291 worktree_store,
292 cx,
293 )
294 .await
295 .context("Failed to create context server configuration")?;
296
297 this.update(cx, |this, cx| {
298 this.run_server(server, Arc::new(configuration), cx)
299 })
300 })
301 .detach_and_log_err(cx);
302 }
303
304 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
305 if matches!(
306 self.servers.get(id),
307 Some(ContextServerState::Stopped { .. })
308 ) {
309 return Ok(());
310 }
311
312 let state = self
313 .servers
314 .remove(id)
315 .context("Context server not found")?;
316
317 let server = state.server();
318 let configuration = state.configuration();
319 let mut result = Ok(());
320 if let ContextServerState::Running { server, .. } = &state {
321 result = server.stop();
322 }
323 drop(state);
324
325 self.update_server_state(
326 id.clone(),
327 ContextServerState::Stopped {
328 configuration,
329 server,
330 },
331 cx,
332 );
333
334 result
335 }
336
337 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
338 if let Some(state) = self.servers.get(&id) {
339 let configuration = state.configuration();
340
341 self.stop_server(&state.server().id(), cx)?;
342 let new_server = self.create_context_server(id.clone(), configuration.clone())?;
343 self.run_server(new_server, configuration, cx);
344 }
345 Ok(())
346 }
347
348 fn run_server(
349 &mut self,
350 server: Arc<ContextServer>,
351 configuration: Arc<ContextServerConfiguration>,
352 cx: &mut Context<Self>,
353 ) {
354 let id = server.id();
355 if matches!(
356 self.servers.get(&id),
357 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
358 ) {
359 self.stop_server(&id, cx).log_err();
360 }
361
362 let task = cx.spawn({
363 let id = server.id();
364 let server = server.clone();
365 let configuration = configuration.clone();
366 async move |this, cx| {
367 match server.clone().start(&cx).await {
368 Ok(_) => {
369 log::info!("Started {} context server", id);
370 debug_assert!(server.client().is_some());
371
372 this.update(cx, |this, cx| {
373 this.update_server_state(
374 id.clone(),
375 ContextServerState::Running {
376 server,
377 configuration,
378 },
379 cx,
380 )
381 })
382 .log_err()
383 }
384 Err(err) => {
385 log::error!("{} context server failed to start: {}", id, err);
386 this.update(cx, |this, cx| {
387 this.update_server_state(
388 id.clone(),
389 ContextServerState::Error {
390 configuration,
391 server,
392 error: err.to_string().into(),
393 },
394 cx,
395 )
396 })
397 .log_err()
398 }
399 };
400 }
401 });
402
403 self.update_server_state(
404 id.clone(),
405 ContextServerState::Starting {
406 configuration,
407 _task: task,
408 server,
409 },
410 cx,
411 );
412 }
413
414 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
415 let state = self
416 .servers
417 .remove(id)
418 .context("Context server not found")?;
419 drop(state);
420 cx.emit(Event::ServerStatusChanged {
421 server_id: id.clone(),
422 status: ContextServerStatus::Stopped,
423 });
424 Ok(())
425 }
426
427 fn create_context_server(
428 &self,
429 id: ContextServerId,
430 configuration: Arc<ContextServerConfiguration>,
431 ) -> Result<Arc<ContextServer>> {
432 if let Some(factory) = self.context_server_factory.as_ref() {
433 Ok(factory(id, configuration))
434 } else {
435 Ok(Arc::new(ContextServer::stdio(
436 id,
437 configuration.command().clone(),
438 )))
439 }
440 }
441
442 fn context_server_settings<'a>(
443 &'a self,
444 cx: &'a App,
445 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
446 let location = self
447 .worktree_store
448 .read(cx)
449 .visible_worktrees(cx)
450 .next()
451 .map(|worktree| settings::SettingsLocation {
452 worktree_id: worktree.read(cx).id(),
453 path: Path::new(""),
454 });
455 &ProjectSettings::get(location, cx).context_servers
456 }
457
458 fn update_server_state(
459 &mut self,
460 id: ContextServerId,
461 state: ContextServerState,
462 cx: &mut Context<Self>,
463 ) {
464 let status = ContextServerStatus::from_state(&state);
465 self.servers.insert(id.clone(), state);
466 cx.emit(Event::ServerStatusChanged {
467 server_id: id,
468 status,
469 });
470 }
471
472 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
473 if self.update_servers_task.is_some() {
474 self.needs_server_update = true;
475 } else {
476 self.needs_server_update = false;
477 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
478 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
479 log::error!("Error maintaining context servers: {}", err);
480 }
481
482 this.update(cx, |this, cx| {
483 this.update_servers_task.take();
484 if this.needs_server_update {
485 this.available_context_servers_changed(cx);
486 }
487 })?;
488
489 Ok(())
490 }));
491 }
492 }
493
    /// Reconciles the tracked servers with the current settings and registry.
    ///
    /// Steps:
    /// 1. Snapshot the configured servers and add an enabled extension entry
    ///    (with empty settings) for every registry descriptor not mentioned
    ///    in the settings.
    /// 2. Split the entries into enabled and disabled, and resolve a
    ///    configuration for each enabled one (entries that fail to resolve
    ///    are dropped).
    /// 3. Decide which tracked servers to stop (disabled, or about to be
    ///    replaced), remove (no longer configured) or start (new, changed
    ///    configuration, or currently stopped).
    /// 4. Apply the plan: stop, then remove, then start.
    ///
    /// # Errors
    /// Fails if the store or app is gone, or if stopping/removing a server
    /// fails; in the latter case the remaining steps of the plan are skipped.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, cx| {
            (
                this.context_server_settings(cx).clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Servers known to the registry but absent from the settings default to enabled.
        for (id, _) in
            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
        {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::Extension {
                    enabled: true,
                    settings: serde_json::json!({}),
                });
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Resolve all enabled configurations concurrently; unresolved ones are filtered out.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(|config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        this.update(cx, |this, _cx| {
            for server_id in this.servers.keys() {
                // Servers that are not in `configured_servers` are either stopped (when they
                // are merely disabled) or removed from the store entirely.
                // Removal happens if the user deleted a server from the context server settings.
                if !configured_servers.contains_key(&server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                // (Re)start when the server is new, its configuration changed, or it is stopped.
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    if let Some(server) = this
                        .create_context_server(id.clone(), config.clone())
                        .log_err()
                    {
                        servers_to_start.push((server, config));
                        // An existing instance must be stopped before its replacement starts.
                        if this.servers.contains_key(&id) {
                            servers_to_stop.insert(id);
                        }
                    }
                }
            }
        })?;

        this.update(cx, |this, cx| {
            for id in servers_to_stop {
                this.stop_server(&id, cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, cx)?;
            }
            for (server, config) in servers_to_start {
                this.run_server(server, config, cx);
            }
            anyhow::Ok(())
        })?
    }
584}
585
586#[cfg(test)]
587mod tests {
588 use super::*;
589 use crate::{
590 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
591 project_settings::ProjectSettings,
592 };
593 use context_server::test::create_fake_transport;
594 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
595 use serde_json::json;
596 use std::{cell::RefCell, rc::Rc};
597 use util::path;
598
    // Starts two servers through a store without the maintain loop and checks
    // the status reported for each after starting both and stopping the second.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";
        const SERVER_2_ID: &'static str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        // Only the first server has been started so far.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        // Stopping the second server must not affect the first.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }
673
    // Checks the exact sequence of status events emitted while starting two
    // servers and then stopping the second one.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";
        const SERVER_2_ID: &'static str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // The guard verifies the event count when it is dropped at the end of the test.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }
730
    // Starts two distinct server instances sharing one id back to back and
    // checks that the first is reported as stopped before the second runs.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }
786
    // Exercises the maintain-servers loop end to end: initial start from
    // settings, restart on configuration change, no-op on unchanged settings,
    // start of a newly added server, restart on changed args, and removal of a
    // server deleted from the settings. Each phase asserts the exact events.
    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";
        const SERVER_2_ID: &'static str = "mcp-2";

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Extension {
                    enabled: true,
                    settings: json!({
                        "somevalue": true
                    }),
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| {
            let mut registry = ContextServerDescriptorRegistry::new();
            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
            registry
        });
        // The factory replaces real stdio servers with fake-transport servers.
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Stopped),
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        ContextServerSettings::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        ContextServerSettings::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".to_string(),
                                args: vec!["arg".to_string()],
                                env: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is restarted once the args have changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Stopped),
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        ContextServerSettings::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        ContextServerSettings::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".to_string(),
                                args: vec!["anotherArg".to_string()],
                                env: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }
    }
997
    // Verifies that toggling `enabled` in the settings stops a running server
    // and starts it again, asserting the exact events for each phase.
    #[gpui::test]
    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";

        let server_1_id = ContextServerId(SERVER_1_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Custom {
                    enabled: true,
                    command: ContextServerCommand {
                        path: "somebinary".to_string(),
                        args: vec!["arg".to_string()],
                        env: None,
                    },
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        // The factory replaces real stdio servers with fake-transport servers.
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is stopped once it is disabled.
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Custom {
                        enabled: false,
                        command: ContextServerCommand {
                            path: "somebinary".to_string(),
                            args: vec!["arg".to_string()],
                            env: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is started once it is enabled again.
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Custom {
                        enabled: true,
                        command: ContextServerCommand {
                            path: "somebinary".to_string(),
                            args: vec!["arg".to_string()],
                            env: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }
    }
1103
1104 fn set_context_server_configuration(
1105 context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1106 cx: &mut TestAppContext,
1107 ) {
1108 cx.update(|cx| {
1109 SettingsStore::update_global(cx, |store, cx| {
1110 let mut settings = ProjectSettings::default();
1111 for (id, config) in context_servers {
1112 settings.context_servers.insert(id, config);
1113 }
1114 store
1115 .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1116 .unwrap();
1117 })
1118 });
1119 }
1120
    /// Guard returned by `assert_server_events`; on drop it checks that the
    /// number of received events equals the number of expected ones.
    struct ServerEvents {
        // Shared counter incremented by the subscription callback.
        received_event_count: Rc<RefCell<usize>>,
        // Number of events the test declared up front.
        expected_event_count: usize,
        // Keeps the event subscription alive as long as the guard lives.
        _subscription: Subscription,
    }
1126
1127 impl Drop for ServerEvents {
1128 fn drop(&mut self) {
1129 let actual_event_count = *self.received_event_count.borrow();
1130 assert_eq!(
1131 actual_event_count, self.expected_event_count,
1132 "
1133 Expected to receive {} context server store events, but received {} events",
1134 self.expected_event_count, actual_event_count
1135 );
1136 }
1137 }
1138
1139 fn dummy_server_settings() -> ContextServerSettings {
1140 ContextServerSettings::Custom {
1141 enabled: true,
1142 command: ContextServerCommand {
1143 path: "somebinary".to_string(),
1144 args: vec!["arg".to_string()],
1145 env: None,
1146 },
1147 }
1148 }
1149
1150 fn assert_server_events(
1151 store: &Entity<ContextServerStore>,
1152 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1153 cx: &mut TestAppContext,
1154 ) -> ServerEvents {
1155 cx.update(|cx| {
1156 let mut ix = 0;
1157 let received_event_count = Rc::new(RefCell::new(0));
1158 let expected_event_count = expected_events.len();
1159 let subscription = cx.subscribe(store, {
1160 let received_event_count = received_event_count.clone();
1161 move |_, event, _| match event {
1162 Event::ServerStatusChanged {
1163 server_id: actual_server_id,
1164 status: actual_status,
1165 } => {
1166 let (expected_server_id, expected_status) = &expected_events[ix];
1167
1168 assert_eq!(
1169 actual_server_id, expected_server_id,
1170 "Expected different server id at index {}",
1171 ix
1172 );
1173 assert_eq!(
1174 actual_status, expected_status,
1175 "Expected different status at index {}",
1176 ix
1177 );
1178 ix += 1;
1179 *received_event_count.borrow_mut() += 1;
1180 }
1181 }
1182 });
1183 ServerEvents {
1184 expected_event_count,
1185 received_event_count,
1186 _subscription: subscription,
1187 }
1188 })
1189 }
1190
1191 async fn setup_context_server_test(
1192 cx: &mut TestAppContext,
1193 files: serde_json::Value,
1194 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1195 ) -> (Arc<FakeFs>, Entity<Project>) {
1196 cx.update(|cx| {
1197 let settings_store = SettingsStore::test(cx);
1198 cx.set_global(settings_store);
1199 Project::init_settings(cx);
1200 let mut settings = ProjectSettings::get_global(cx).clone();
1201 for (id, config) in context_server_configurations {
1202 settings.context_servers.insert(id, config);
1203 }
1204 ProjectSettings::override_global(settings, cx);
1205 });
1206
1207 let fs = FakeFs::new(cx.executor());
1208 fs.insert_tree(path!("/test"), files).await;
1209 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1210
1211 (fs, project)
1212 }
1213
1214 struct FakeContextServerDescriptor {
1215 path: String,
1216 }
1217
1218 impl FakeContextServerDescriptor {
1219 fn new(path: impl Into<String>) -> Self {
1220 Self { path: path.into() }
1221 }
1222 }
1223
1224 impl ContextServerDescriptor for FakeContextServerDescriptor {
1225 fn command(
1226 &self,
1227 _worktree_store: Entity<WorktreeStore>,
1228 _cx: &AsyncApp,
1229 ) -> Task<Result<ContextServerCommand>> {
1230 Task::ready(Ok(ContextServerCommand {
1231 path: self.path.clone(),
1232 args: vec!["arg1".to_string(), "arg2".to_string()],
1233 env: None,
1234 }))
1235 }
1236
1237 fn configuration(
1238 &self,
1239 _worktree_store: Entity<WorktreeStore>,
1240 _cx: &AsyncApp,
1241 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1242 Task::ready(Ok(None))
1243 }
1244 }
1245}