1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::ResultExt as _;
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 worktree_store::WorktreeStore,
19};
20
21pub fn init(cx: &mut App) {
22 extension::init(cx);
23}
24
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
32
/// Public lifecycle status of a context server, as reported to observers.
///
/// This is a handle-free projection of the store's internal per-server state: it carries
/// neither the server nor its configuration, only which phase the server is in.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// A start attempt has been spawned and has not finished yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped or removed by the store.
    Stopped,
    /// The last start attempt failed; holds the error's display text.
    Error(Arc<str>),
}
40
41impl ContextServerStatus {
42 fn from_state(state: &ContextServerState) -> Self {
43 match state {
44 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
45 ContextServerState::Running { .. } => ContextServerStatus::Running,
46 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
47 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
48 }
49 }
50}
51
/// Internal per-server state tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was launched with, so the
/// server can be restarted or compared against new settings regardless of its phase.
enum ContextServerState {
    /// A start attempt is in flight.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Owns the spawned startup future; replacing or dropping this state drops the task
        // (gpui tasks are cancelled on drop).
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped by the store.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The start attempt failed with the given error message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
72
73impl ContextServerState {
74 pub fn server(&self) -> Arc<ContextServer> {
75 match self {
76 ContextServerState::Starting { server, .. } => server.clone(),
77 ContextServerState::Running { server, .. } => server.clone(),
78 ContextServerState::Stopped { server, .. } => server.clone(),
79 ContextServerState::Error { server, .. } => server.clone(),
80 }
81 }
82
83 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
84 match self {
85 ContextServerState::Starting { configuration, .. } => configuration.clone(),
86 ContextServerState::Running { configuration, .. } => configuration.clone(),
87 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
88 ContextServerState::Error { configuration, .. } => configuration.clone(),
89 }
90 }
91}
92
/// Fully resolved configuration of a context server.
///
/// Compared with `PartialEq` by the store to decide whether a running server must be
/// restarted after a settings or registry change.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server whose command comes directly from the user's settings.
    Custom {
        command: ContextServerCommand,
    },
    /// A server provided by an extension descriptor, together with its extension settings.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
103
104impl ContextServerConfiguration {
105 pub fn command(&self) -> &ContextServerCommand {
106 match self {
107 ContextServerConfiguration::Custom { command } => command,
108 ContextServerConfiguration::Extension { command, .. } => command,
109 }
110 }
111
112 pub async fn from_settings(
113 settings: ContextServerSettings,
114 id: ContextServerId,
115 registry: Entity<ContextServerDescriptorRegistry>,
116 worktree_store: Entity<WorktreeStore>,
117 cx: &AsyncApp,
118 ) -> Option<Self> {
119 match settings {
120 ContextServerSettings::Custom {
121 enabled: _,
122 command,
123 } => Some(ContextServerConfiguration::Custom { command }),
124 ContextServerSettings::Extension {
125 enabled: _,
126 settings,
127 } => {
128 let descriptor = cx
129 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
130 .ok()
131 .flatten()?;
132
133 let command = descriptor.command(worktree_store, cx).await.log_err()?;
134
135 Some(ContextServerConfiguration::Extension { command, settings })
136 }
137 }
138 }
139}
140
/// Builds a server for the given id and configuration.
///
/// When set on the store, it is used instead of launching a stdio process; the test
/// constructors in this file use it to inject fake transports.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
143
/// Owns the context servers of a project and keeps them in sync with the user's settings
/// and the extension-provided descriptor registry.
pub struct ContextServerStore {
    // Context server settings resolved for the project, keyed by the server id string.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current lifecycle state of every server known to the store.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    // Used to find the active project directory when launching stdio servers.
    project: WeakEntity<Project>,
    // Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // The reconciliation pass currently in flight, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // Overrides how servers are built; `None` means launch stdio processes.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a reconciliation was requested while another pass was still running.
    needs_server_update: bool,
    _subscriptions: Vec<Subscription>,
}
155
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// A server's status changed (or the server was removed, reported as `Stopped`).
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}
162
// Allows observers to subscribe to the store's `Event`s.
impl EventEmitter<Event> for ContextServerStore {}
164
165impl ContextServerStore {
166 pub fn new(
167 worktree_store: Entity<WorktreeStore>,
168 weak_project: WeakEntity<Project>,
169 cx: &mut Context<Self>,
170 ) -> Self {
171 Self::new_internal(
172 true,
173 None,
174 ContextServerDescriptorRegistry::default_global(cx),
175 worktree_store,
176 weak_project,
177 cx,
178 )
179 }
180
181 /// Returns all configured context server ids, regardless of enabled state.
182 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
183 self.context_server_settings
184 .keys()
185 .cloned()
186 .map(ContextServerId)
187 .collect()
188 }
189
190 #[cfg(any(test, feature = "test-support"))]
191 pub fn test(
192 registry: Entity<ContextServerDescriptorRegistry>,
193 worktree_store: Entity<WorktreeStore>,
194 weak_project: WeakEntity<Project>,
195 cx: &mut Context<Self>,
196 ) -> Self {
197 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
198 }
199
200 #[cfg(any(test, feature = "test-support"))]
201 pub fn test_maintain_server_loop(
202 context_server_factory: ContextServerFactory,
203 registry: Entity<ContextServerDescriptorRegistry>,
204 worktree_store: Entity<WorktreeStore>,
205 weak_project: WeakEntity<Project>,
206 cx: &mut Context<Self>,
207 ) -> Self {
208 Self::new_internal(
209 true,
210 Some(context_server_factory),
211 registry,
212 worktree_store,
213 weak_project,
214 cx,
215 )
216 }
217
218 fn new_internal(
219 maintain_server_loop: bool,
220 context_server_factory: Option<ContextServerFactory>,
221 registry: Entity<ContextServerDescriptorRegistry>,
222 worktree_store: Entity<WorktreeStore>,
223 weak_project: WeakEntity<Project>,
224 cx: &mut Context<Self>,
225 ) -> Self {
226 let subscriptions = if maintain_server_loop {
227 vec![
228 cx.observe(®istry, |this, _registry, cx| {
229 this.available_context_servers_changed(cx);
230 }),
231 cx.observe_global::<SettingsStore>(|this, cx| {
232 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
233 if &this.context_server_settings == settings {
234 return;
235 }
236 this.context_server_settings = settings.clone();
237 this.available_context_servers_changed(cx);
238 }),
239 ]
240 } else {
241 Vec::new()
242 };
243
244 let mut this = Self {
245 _subscriptions: subscriptions,
246 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
247 .clone(),
248 worktree_store,
249 project: weak_project,
250 registry,
251 needs_server_update: false,
252 servers: HashMap::default(),
253 update_servers_task: None,
254 context_server_factory,
255 };
256 if maintain_server_loop {
257 this.available_context_servers_changed(cx);
258 }
259 this
260 }
261
262 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
263 self.servers.get(id).map(|state| state.server())
264 }
265
266 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
267 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
268 Some(server.clone())
269 } else {
270 None
271 }
272 }
273
274 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
275 self.servers.get(id).map(ContextServerStatus::from_state)
276 }
277
278 pub fn configuration_for_server(
279 &self,
280 id: &ContextServerId,
281 ) -> Option<Arc<ContextServerConfiguration>> {
282 self.servers.get(id).map(|state| state.configuration())
283 }
284
285 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
286 self.servers.keys().cloned().collect()
287 }
288
289 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
290 self.servers
291 .values()
292 .filter_map(|state| {
293 if let ContextServerState::Running { server, .. } = state {
294 Some(server.clone())
295 } else {
296 None
297 }
298 })
299 .collect()
300 }
301
302 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
303 cx.spawn(async move |this, cx| {
304 let this = this.upgrade().context("Context server store dropped")?;
305 let settings = this
306 .update(cx, |this, _| {
307 this.context_server_settings.get(&server.id().0).cloned()
308 })
309 .ok()
310 .flatten()
311 .context("Failed to get context server settings")?;
312
313 if !settings.enabled() {
314 return Ok(());
315 }
316
317 let (registry, worktree_store) = this.update(cx, |this, _| {
318 (this.registry.clone(), this.worktree_store.clone())
319 })?;
320 let configuration = ContextServerConfiguration::from_settings(
321 settings,
322 server.id(),
323 registry,
324 worktree_store,
325 cx,
326 )
327 .await
328 .context("Failed to create context server configuration")?;
329
330 this.update(cx, |this, cx| {
331 this.run_server(server, Arc::new(configuration), cx)
332 })
333 })
334 .detach_and_log_err(cx);
335 }
336
337 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
338 if matches!(
339 self.servers.get(id),
340 Some(ContextServerState::Stopped { .. })
341 ) {
342 return Ok(());
343 }
344
345 let state = self
346 .servers
347 .remove(id)
348 .context("Context server not found")?;
349
350 let server = state.server();
351 let configuration = state.configuration();
352 let mut result = Ok(());
353 if let ContextServerState::Running { server, .. } = &state {
354 result = server.stop();
355 }
356 drop(state);
357
358 self.update_server_state(
359 id.clone(),
360 ContextServerState::Stopped {
361 configuration,
362 server,
363 },
364 cx,
365 );
366
367 result
368 }
369
370 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
371 if let Some(state) = self.servers.get(id) {
372 let configuration = state.configuration();
373
374 self.stop_server(&state.server().id(), cx)?;
375 let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
376 self.run_server(new_server, configuration, cx);
377 }
378 Ok(())
379 }
380
381 fn run_server(
382 &mut self,
383 server: Arc<ContextServer>,
384 configuration: Arc<ContextServerConfiguration>,
385 cx: &mut Context<Self>,
386 ) {
387 let id = server.id();
388 if matches!(
389 self.servers.get(&id),
390 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
391 ) {
392 self.stop_server(&id, cx).log_err();
393 }
394
395 let task = cx.spawn({
396 let id = server.id();
397 let server = server.clone();
398 let configuration = configuration.clone();
399 async move |this, cx| {
400 match server.clone().start(cx).await {
401 Ok(_) => {
402 log::info!("Started {} context server", id);
403 debug_assert!(server.client().is_some());
404
405 this.update(cx, |this, cx| {
406 this.update_server_state(
407 id.clone(),
408 ContextServerState::Running {
409 server,
410 configuration,
411 },
412 cx,
413 )
414 })
415 .log_err()
416 }
417 Err(err) => {
418 log::error!("{} context server failed to start: {}", id, err);
419 this.update(cx, |this, cx| {
420 this.update_server_state(
421 id.clone(),
422 ContextServerState::Error {
423 configuration,
424 server,
425 error: err.to_string().into(),
426 },
427 cx,
428 )
429 })
430 .log_err()
431 }
432 };
433 }
434 });
435
436 self.update_server_state(
437 id.clone(),
438 ContextServerState::Starting {
439 configuration,
440 _task: task,
441 server,
442 },
443 cx,
444 );
445 }
446
447 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
448 let state = self
449 .servers
450 .remove(id)
451 .context("Context server not found")?;
452 drop(state);
453 cx.emit(Event::ServerStatusChanged {
454 server_id: id.clone(),
455 status: ContextServerStatus::Stopped,
456 });
457 Ok(())
458 }
459
460 fn create_context_server(
461 &self,
462 id: ContextServerId,
463 configuration: Arc<ContextServerConfiguration>,
464 cx: &mut Context<Self>,
465 ) -> Arc<ContextServer> {
466 let root_path = self
467 .project
468 .read_with(cx, |project, cx| project.active_project_directory(cx))
469 .ok()
470 .flatten()
471 .or_else(|| {
472 self.worktree_store.read_with(cx, |store, cx| {
473 store.visible_worktrees(cx).fold(None, |acc, item| {
474 if acc.is_none() {
475 item.read(cx).root_dir()
476 } else {
477 acc
478 }
479 })
480 })
481 });
482
483 if let Some(factory) = self.context_server_factory.as_ref() {
484 factory(id, configuration)
485 } else {
486 Arc::new(ContextServer::stdio(
487 id,
488 configuration.command().clone(),
489 root_path,
490 ))
491 }
492 }
493
494 fn resolve_context_server_settings<'a>(
495 worktree_store: &'a Entity<WorktreeStore>,
496 cx: &'a App,
497 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
498 let location = worktree_store
499 .read(cx)
500 .visible_worktrees(cx)
501 .next()
502 .map(|worktree| settings::SettingsLocation {
503 worktree_id: worktree.read(cx).id(),
504 path: Path::new(""),
505 });
506 &ProjectSettings::get(location, cx).context_servers
507 }
508
509 fn update_server_state(
510 &mut self,
511 id: ContextServerId,
512 state: ContextServerState,
513 cx: &mut Context<Self>,
514 ) {
515 let status = ContextServerStatus::from_state(&state);
516 self.servers.insert(id.clone(), state);
517 cx.emit(Event::ServerStatusChanged {
518 server_id: id,
519 status,
520 });
521 }
522
523 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
524 if self.update_servers_task.is_some() {
525 self.needs_server_update = true;
526 } else {
527 self.needs_server_update = false;
528 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
529 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
530 log::error!("Error maintaining context servers: {}", err);
531 }
532
533 this.update(cx, |this, cx| {
534 this.update_servers_task.take();
535 if this.needs_server_update {
536 this.available_context_servers_changed(cx);
537 }
538 })?;
539
540 Ok(())
541 }));
542 }
543 }
544
545 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
546 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
547 (
548 this.context_server_settings.clone(),
549 this.registry.clone(),
550 this.worktree_store.clone(),
551 )
552 })?;
553
554 for (id, _) in
555 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
556 {
557 configured_servers
558 .entry(id)
559 .or_insert(ContextServerSettings::default_extension());
560 }
561
562 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
563 configured_servers
564 .into_iter()
565 .partition(|(_, settings)| settings.enabled());
566
567 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
568 let id = ContextServerId(id);
569 ContextServerConfiguration::from_settings(
570 settings,
571 id.clone(),
572 registry.clone(),
573 worktree_store.clone(),
574 cx,
575 )
576 .map(|config| (id, config))
577 }))
578 .await
579 .into_iter()
580 .filter_map(|(id, config)| config.map(|config| (id, config)))
581 .collect::<HashMap<_, _>>();
582
583 let mut servers_to_start = Vec::new();
584 let mut servers_to_remove = HashSet::default();
585 let mut servers_to_stop = HashSet::default();
586
587 this.update(cx, |this, cx| {
588 for server_id in this.servers.keys() {
589 // All servers that are not in desired_servers should be removed from the store.
590 // This can happen if the user removed a server from the context server settings.
591 if !configured_servers.contains_key(server_id) {
592 if disabled_servers.contains_key(&server_id.0) {
593 servers_to_stop.insert(server_id.clone());
594 } else {
595 servers_to_remove.insert(server_id.clone());
596 }
597 }
598 }
599
600 for (id, config) in configured_servers {
601 let state = this.servers.get(&id);
602 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
603 let existing_config = state.as_ref().map(|state| state.configuration());
604 if existing_config.as_deref() != Some(&config) || is_stopped {
605 let config = Arc::new(config);
606 let server = this.create_context_server(id.clone(), config.clone(), cx);
607 servers_to_start.push((server, config));
608 if this.servers.contains_key(&id) {
609 servers_to_stop.insert(id);
610 }
611 }
612 }
613 })?;
614
615 this.update(cx, |this, cx| {
616 for id in servers_to_stop {
617 this.stop_server(&id, cx)?;
618 }
619 for id in servers_to_remove {
620 this.remove_server(&id, cx)?;
621 }
622 for (server, config) in servers_to_start {
623 this.run_server(server, config, cx);
624 }
625 anyhow::Ok(())
626 })?
627 }
628}
629
630#[cfg(test)]
631mod tests {
632 use super::*;
633 use crate::{
634 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
635 project_settings::ProjectSettings,
636 };
637 use context_server::test::create_fake_transport;
638 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
639 use serde_json::json;
640 use std::{cell::RefCell, path::PathBuf, rc::Rc};
641 use util::path;
642
643 #[gpui::test]
644 async fn test_context_server_status(cx: &mut TestAppContext) {
645 const SERVER_1_ID: &'static str = "mcp-1";
646 const SERVER_2_ID: &'static str = "mcp-2";
647
648 let (_fs, project) = setup_context_server_test(
649 cx,
650 json!({"code.rs": ""}),
651 vec![
652 (SERVER_1_ID.into(), dummy_server_settings()),
653 (SERVER_2_ID.into(), dummy_server_settings()),
654 ],
655 )
656 .await;
657
658 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
659 let store = cx.new(|cx| {
660 ContextServerStore::test(
661 registry.clone(),
662 project.read(cx).worktree_store(),
663 project.downgrade(),
664 cx,
665 )
666 });
667
668 let server_1_id = ContextServerId(SERVER_1_ID.into());
669 let server_2_id = ContextServerId(SERVER_2_ID.into());
670
671 let server_1 = Arc::new(ContextServer::new(
672 server_1_id.clone(),
673 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
674 ));
675 let server_2 = Arc::new(ContextServer::new(
676 server_2_id.clone(),
677 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
678 ));
679
680 store.update(cx, |store, cx| store.start_server(server_1, cx));
681
682 cx.run_until_parked();
683
684 cx.update(|cx| {
685 assert_eq!(
686 store.read(cx).status_for_server(&server_1_id),
687 Some(ContextServerStatus::Running)
688 );
689 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
690 });
691
692 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
693
694 cx.run_until_parked();
695
696 cx.update(|cx| {
697 assert_eq!(
698 store.read(cx).status_for_server(&server_1_id),
699 Some(ContextServerStatus::Running)
700 );
701 assert_eq!(
702 store.read(cx).status_for_server(&server_2_id),
703 Some(ContextServerStatus::Running)
704 );
705 });
706
707 store
708 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
709 .unwrap();
710
711 cx.update(|cx| {
712 assert_eq!(
713 store.read(cx).status_for_server(&server_1_id),
714 Some(ContextServerStatus::Running)
715 );
716 assert_eq!(
717 store.read(cx).status_for_server(&server_2_id),
718 Some(ContextServerStatus::Stopped)
719 );
720 });
721 }
722
723 #[gpui::test]
724 async fn test_context_server_status_events(cx: &mut TestAppContext) {
725 const SERVER_1_ID: &'static str = "mcp-1";
726 const SERVER_2_ID: &'static str = "mcp-2";
727
728 let (_fs, project) = setup_context_server_test(
729 cx,
730 json!({"code.rs": ""}),
731 vec![
732 (SERVER_1_ID.into(), dummy_server_settings()),
733 (SERVER_2_ID.into(), dummy_server_settings()),
734 ],
735 )
736 .await;
737
738 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
739 let store = cx.new(|cx| {
740 ContextServerStore::test(
741 registry.clone(),
742 project.read(cx).worktree_store(),
743 project.downgrade(),
744 cx,
745 )
746 });
747
748 let server_1_id = ContextServerId(SERVER_1_ID.into());
749 let server_2_id = ContextServerId(SERVER_2_ID.into());
750
751 let server_1 = Arc::new(ContextServer::new(
752 server_1_id.clone(),
753 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
754 ));
755 let server_2 = Arc::new(ContextServer::new(
756 server_2_id.clone(),
757 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
758 ));
759
760 let _server_events = assert_server_events(
761 &store,
762 vec![
763 (server_1_id.clone(), ContextServerStatus::Starting),
764 (server_1_id.clone(), ContextServerStatus::Running),
765 (server_2_id.clone(), ContextServerStatus::Starting),
766 (server_2_id.clone(), ContextServerStatus::Running),
767 (server_2_id.clone(), ContextServerStatus::Stopped),
768 ],
769 cx,
770 );
771
772 store.update(cx, |store, cx| store.start_server(server_1, cx));
773
774 cx.run_until_parked();
775
776 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
777
778 cx.run_until_parked();
779
780 store
781 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
782 .unwrap();
783 }
784
785 #[gpui::test(iterations = 25)]
786 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
787 const SERVER_1_ID: &'static str = "mcp-1";
788
789 let (_fs, project) = setup_context_server_test(
790 cx,
791 json!({"code.rs": ""}),
792 vec![(SERVER_1_ID.into(), dummy_server_settings())],
793 )
794 .await;
795
796 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
797 let store = cx.new(|cx| {
798 ContextServerStore::test(
799 registry.clone(),
800 project.read(cx).worktree_store(),
801 project.downgrade(),
802 cx,
803 )
804 });
805
806 let server_id = ContextServerId(SERVER_1_ID.into());
807
808 let server_with_same_id_1 = Arc::new(ContextServer::new(
809 server_id.clone(),
810 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
811 ));
812 let server_with_same_id_2 = Arc::new(ContextServer::new(
813 server_id.clone(),
814 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
815 ));
816
817 // If we start another server with the same id, we should report that we stopped the previous one
818 let _server_events = assert_server_events(
819 &store,
820 vec![
821 (server_id.clone(), ContextServerStatus::Starting),
822 (server_id.clone(), ContextServerStatus::Stopped),
823 (server_id.clone(), ContextServerStatus::Starting),
824 (server_id.clone(), ContextServerStatus::Running),
825 ],
826 cx,
827 );
828
829 store.update(cx, |store, cx| {
830 store.start_server(server_with_same_id_1.clone(), cx)
831 });
832 store.update(cx, |store, cx| {
833 store.start_server(server_with_same_id_2.clone(), cx)
834 });
835
836 cx.run_until_parked();
837
838 cx.update(|cx| {
839 assert_eq!(
840 store.read(cx).status_for_server(&server_id),
841 Some(ContextServerStatus::Running)
842 );
843 });
844 }
845
846 #[gpui::test]
847 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
848 const SERVER_1_ID: &'static str = "mcp-1";
849 const SERVER_2_ID: &'static str = "mcp-2";
850
851 let server_1_id = ContextServerId(SERVER_1_ID.into());
852 let server_2_id = ContextServerId(SERVER_2_ID.into());
853
854 let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
855
856 let (_fs, project) = setup_context_server_test(
857 cx,
858 json!({"code.rs": ""}),
859 vec![(
860 SERVER_1_ID.into(),
861 ContextServerSettings::Extension {
862 enabled: true,
863 settings: json!({
864 "somevalue": true
865 }),
866 },
867 )],
868 )
869 .await;
870
871 let executor = cx.executor();
872 let registry = cx.new(|cx| {
873 let mut registry = ContextServerDescriptorRegistry::new();
874 registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
875 registry
876 });
877 let store = cx.new(|cx| {
878 ContextServerStore::test_maintain_server_loop(
879 Box::new(move |id, _| {
880 Arc::new(ContextServer::new(
881 id.clone(),
882 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
883 ))
884 }),
885 registry.clone(),
886 project.read(cx).worktree_store(),
887 project.downgrade(),
888 cx,
889 )
890 });
891
892 // Ensure that mcp-1 starts up
893 {
894 let _server_events = assert_server_events(
895 &store,
896 vec![
897 (server_1_id.clone(), ContextServerStatus::Starting),
898 (server_1_id.clone(), ContextServerStatus::Running),
899 ],
900 cx,
901 );
902 cx.run_until_parked();
903 }
904
905 // Ensure that mcp-1 is restarted when the configuration was changed
906 {
907 let _server_events = assert_server_events(
908 &store,
909 vec![
910 (server_1_id.clone(), ContextServerStatus::Stopped),
911 (server_1_id.clone(), ContextServerStatus::Starting),
912 (server_1_id.clone(), ContextServerStatus::Running),
913 ],
914 cx,
915 );
916 set_context_server_configuration(
917 vec![(
918 server_1_id.0.clone(),
919 ContextServerSettings::Extension {
920 enabled: true,
921 settings: json!({
922 "somevalue": false
923 }),
924 },
925 )],
926 cx,
927 );
928
929 cx.run_until_parked();
930 }
931
932 // Ensure that mcp-1 is not restarted when the configuration was not changed
933 {
934 let _server_events = assert_server_events(&store, vec![], cx);
935 set_context_server_configuration(
936 vec![(
937 server_1_id.0.clone(),
938 ContextServerSettings::Extension {
939 enabled: true,
940 settings: json!({
941 "somevalue": false
942 }),
943 },
944 )],
945 cx,
946 );
947
948 cx.run_until_parked();
949 }
950
951 // Ensure that mcp-2 is started once it is added to the settings
952 {
953 let _server_events = assert_server_events(
954 &store,
955 vec![
956 (server_2_id.clone(), ContextServerStatus::Starting),
957 (server_2_id.clone(), ContextServerStatus::Running),
958 ],
959 cx,
960 );
961 set_context_server_configuration(
962 vec![
963 (
964 server_1_id.0.clone(),
965 ContextServerSettings::Extension {
966 enabled: true,
967 settings: json!({
968 "somevalue": false
969 }),
970 },
971 ),
972 (
973 server_2_id.0.clone(),
974 ContextServerSettings::Custom {
975 enabled: true,
976 command: ContextServerCommand {
977 path: "somebinary".into(),
978 args: vec!["arg".to_string()],
979 env: None,
980 },
981 },
982 ),
983 ],
984 cx,
985 );
986
987 cx.run_until_parked();
988 }
989
990 // Ensure that mcp-2 is restarted once the args have changed
991 {
992 let _server_events = assert_server_events(
993 &store,
994 vec![
995 (server_2_id.clone(), ContextServerStatus::Stopped),
996 (server_2_id.clone(), ContextServerStatus::Starting),
997 (server_2_id.clone(), ContextServerStatus::Running),
998 ],
999 cx,
1000 );
1001 set_context_server_configuration(
1002 vec![
1003 (
1004 server_1_id.0.clone(),
1005 ContextServerSettings::Extension {
1006 enabled: true,
1007 settings: json!({
1008 "somevalue": false
1009 }),
1010 },
1011 ),
1012 (
1013 server_2_id.0.clone(),
1014 ContextServerSettings::Custom {
1015 enabled: true,
1016 command: ContextServerCommand {
1017 path: "somebinary".into(),
1018 args: vec!["anotherArg".to_string()],
1019 env: None,
1020 },
1021 },
1022 ),
1023 ],
1024 cx,
1025 );
1026
1027 cx.run_until_parked();
1028 }
1029
1030 // Ensure that mcp-2 is removed once it is removed from the settings
1031 {
1032 let _server_events = assert_server_events(
1033 &store,
1034 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1035 cx,
1036 );
1037 set_context_server_configuration(
1038 vec![(
1039 server_1_id.0.clone(),
1040 ContextServerSettings::Extension {
1041 enabled: true,
1042 settings: json!({
1043 "somevalue": false
1044 }),
1045 },
1046 )],
1047 cx,
1048 );
1049
1050 cx.run_until_parked();
1051
1052 cx.update(|cx| {
1053 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1054 });
1055 }
1056
1057 // Ensure that nothing happens if the settings do not change
1058 {
1059 let _server_events = assert_server_events(&store, vec![], cx);
1060 set_context_server_configuration(
1061 vec![(
1062 server_1_id.0.clone(),
1063 ContextServerSettings::Extension {
1064 enabled: true,
1065 settings: json!({
1066 "somevalue": false
1067 }),
1068 },
1069 )],
1070 cx,
1071 );
1072
1073 cx.run_until_parked();
1074
1075 cx.update(|cx| {
1076 assert_eq!(
1077 store.read(cx).status_for_server(&server_1_id),
1078 Some(ContextServerStatus::Running)
1079 );
1080 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1081 });
1082 }
1083 }
1084
1085 #[gpui::test]
1086 async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1087 const SERVER_1_ID: &'static str = "mcp-1";
1088
1089 let server_1_id = ContextServerId(SERVER_1_ID.into());
1090
1091 let (_fs, project) = setup_context_server_test(
1092 cx,
1093 json!({"code.rs": ""}),
1094 vec![(
1095 SERVER_1_ID.into(),
1096 ContextServerSettings::Custom {
1097 enabled: true,
1098 command: ContextServerCommand {
1099 path: "somebinary".into(),
1100 args: vec!["arg".to_string()],
1101 env: None,
1102 },
1103 },
1104 )],
1105 )
1106 .await;
1107
1108 let executor = cx.executor();
1109 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1110 let store = cx.new(|cx| {
1111 ContextServerStore::test_maintain_server_loop(
1112 Box::new(move |id, _| {
1113 Arc::new(ContextServer::new(
1114 id.clone(),
1115 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1116 ))
1117 }),
1118 registry.clone(),
1119 project.read(cx).worktree_store(),
1120 project.downgrade(),
1121 cx,
1122 )
1123 });
1124
1125 // Ensure that mcp-1 starts up
1126 {
1127 let _server_events = assert_server_events(
1128 &store,
1129 vec![
1130 (server_1_id.clone(), ContextServerStatus::Starting),
1131 (server_1_id.clone(), ContextServerStatus::Running),
1132 ],
1133 cx,
1134 );
1135 cx.run_until_parked();
1136 }
1137
1138 // Ensure that mcp-1 is stopped once it is disabled.
1139 {
1140 let _server_events = assert_server_events(
1141 &store,
1142 vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1143 cx,
1144 );
1145 set_context_server_configuration(
1146 vec![(
1147 server_1_id.0.clone(),
1148 ContextServerSettings::Custom {
1149 enabled: false,
1150 command: ContextServerCommand {
1151 path: "somebinary".into(),
1152 args: vec!["arg".to_string()],
1153 env: None,
1154 },
1155 },
1156 )],
1157 cx,
1158 );
1159
1160 cx.run_until_parked();
1161 }
1162
1163 // Ensure that mcp-1 is started once it is enabled again.
1164 {
1165 let _server_events = assert_server_events(
1166 &store,
1167 vec![
1168 (server_1_id.clone(), ContextServerStatus::Starting),
1169 (server_1_id.clone(), ContextServerStatus::Running),
1170 ],
1171 cx,
1172 );
1173 set_context_server_configuration(
1174 vec![(
1175 server_1_id.0.clone(),
1176 ContextServerSettings::Custom {
1177 enabled: true,
1178 command: ContextServerCommand {
1179 path: "somebinary".into(),
1180 args: vec!["arg".to_string()],
1181 env: None,
1182 },
1183 },
1184 )],
1185 cx,
1186 );
1187
1188 cx.run_until_parked();
1189 }
1190 }
1191
1192 fn set_context_server_configuration(
1193 context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1194 cx: &mut TestAppContext,
1195 ) {
1196 cx.update(|cx| {
1197 SettingsStore::update_global(cx, |store, cx| {
1198 let mut settings = ProjectSettings::default();
1199 for (id, config) in context_servers {
1200 settings.context_servers.insert(id, config);
1201 }
1202 store
1203 .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1204 .unwrap();
1205 })
1206 });
1207 }
1208
    /// Guard returned by `assert_server_events`. It keeps the event subscription alive and
    /// checks the number of received events when dropped.
    struct ServerEvents {
        // Incremented by the subscription callback for every received event.
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }
1214
1215 impl Drop for ServerEvents {
1216 fn drop(&mut self) {
1217 let actual_event_count = *self.received_event_count.borrow();
1218 assert_eq!(
1219 actual_event_count, self.expected_event_count,
1220 "
1221 Expected to receive {} context server store events, but received {} events",
1222 self.expected_event_count, actual_event_count
1223 );
1224 }
1225 }
1226
1227 fn dummy_server_settings() -> ContextServerSettings {
1228 ContextServerSettings::Custom {
1229 enabled: true,
1230 command: ContextServerCommand {
1231 path: "somebinary".into(),
1232 args: vec!["arg".to_string()],
1233 env: None,
1234 },
1235 }
1236 }
1237
1238 fn assert_server_events(
1239 store: &Entity<ContextServerStore>,
1240 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1241 cx: &mut TestAppContext,
1242 ) -> ServerEvents {
1243 cx.update(|cx| {
1244 let mut ix = 0;
1245 let received_event_count = Rc::new(RefCell::new(0));
1246 let expected_event_count = expected_events.len();
1247 let subscription = cx.subscribe(store, {
1248 let received_event_count = received_event_count.clone();
1249 move |_, event, _| match event {
1250 Event::ServerStatusChanged {
1251 server_id: actual_server_id,
1252 status: actual_status,
1253 } => {
1254 let (expected_server_id, expected_status) = &expected_events[ix];
1255
1256 assert_eq!(
1257 actual_server_id, expected_server_id,
1258 "Expected different server id at index {}",
1259 ix
1260 );
1261 assert_eq!(
1262 actual_status, expected_status,
1263 "Expected different status at index {}",
1264 ix
1265 );
1266 ix += 1;
1267 *received_event_count.borrow_mut() += 1;
1268 }
1269 }
1270 });
1271 ServerEvents {
1272 expected_event_count,
1273 received_event_count,
1274 _subscription: subscription,
1275 }
1276 })
1277 }
1278
1279 async fn setup_context_server_test(
1280 cx: &mut TestAppContext,
1281 files: serde_json::Value,
1282 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1283 ) -> (Arc<FakeFs>, Entity<Project>) {
1284 cx.update(|cx| {
1285 let settings_store = SettingsStore::test(cx);
1286 cx.set_global(settings_store);
1287 Project::init_settings(cx);
1288 let mut settings = ProjectSettings::get_global(cx).clone();
1289 for (id, config) in context_server_configurations {
1290 settings.context_servers.insert(id, config);
1291 }
1292 ProjectSettings::override_global(settings, cx);
1293 });
1294
1295 let fs = FakeFs::new(cx.executor());
1296 fs.insert_tree(path!("/test"), files).await;
1297 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1298
1299 (fs, project)
1300 }
1301
    /// Descriptor stub that always resolves to a fixed command at `path`.
    struct FakeContextServerDescriptor {
        path: PathBuf,
    }
1305
1306 impl FakeContextServerDescriptor {
1307 fn new(path: impl Into<PathBuf>) -> Self {
1308 Self { path: path.into() }
1309 }
1310 }
1311
1312 impl ContextServerDescriptor for FakeContextServerDescriptor {
1313 fn command(
1314 &self,
1315 _worktree_store: Entity<WorktreeStore>,
1316 _cx: &AsyncApp,
1317 ) -> Task<Result<ContextServerCommand>> {
1318 Task::ready(Ok(ContextServerCommand {
1319 path: self.path.clone(),
1320 args: vec!["arg1".to_string(), "arg2".to_string()],
1321 env: None,
1322 }))
1323 }
1324
1325 fn configuration(
1326 &self,
1327 _worktree_store: Entity<WorktreeStore>,
1328 _cx: &AsyncApp,
1329 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1330 Task::ready(Ok(None))
1331 }
1332 }
1333}