1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::ResultExt as _;
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 worktree_store::WorktreeStore,
19};
20
21pub fn init(cx: &mut App) {
22 extension::init(cx);
23}
24
25actions!(
26 context_server,
27 [
28 /// Restarts the context server.
29 Restart
30 ]
31);
32
/// Public, data-free view of a context server's lifecycle state, as reported
/// by `ContextServerStore::status_for_server` and in `Event::ServerStatusChanged`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// A start attempt is in flight.
    Starting,
    /// The server started successfully.
    Running,
    /// The server is known to the store but is not running.
    Stopped,
    /// The last start attempt failed; carries the error rendered as text.
    Error(Arc<str>),
}
40
41impl ContextServerStatus {
42 fn from_state(state: &ContextServerState) -> Self {
43 match state {
44 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
45 ContextServerState::Running { .. } => ContextServerStatus::Running,
46 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
47 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
48 }
49 }
50}
51
52enum ContextServerState {
53 Starting {
54 server: Arc<ContextServer>,
55 configuration: Arc<ContextServerConfiguration>,
56 _task: Task<()>,
57 },
58 Running {
59 server: Arc<ContextServer>,
60 configuration: Arc<ContextServerConfiguration>,
61 },
62 Stopped {
63 server: Arc<ContextServer>,
64 configuration: Arc<ContextServerConfiguration>,
65 },
66 Error {
67 server: Arc<ContextServer>,
68 configuration: Arc<ContextServerConfiguration>,
69 error: Arc<str>,
70 },
71}
72
73impl ContextServerState {
74 pub fn server(&self) -> Arc<ContextServer> {
75 match self {
76 ContextServerState::Starting { server, .. } => server.clone(),
77 ContextServerState::Running { server, .. } => server.clone(),
78 ContextServerState::Stopped { server, .. } => server.clone(),
79 ContextServerState::Error { server, .. } => server.clone(),
80 }
81 }
82
83 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
84 match self {
85 ContextServerState::Starting { configuration, .. } => configuration.clone(),
86 ContextServerState::Running { configuration, .. } => configuration.clone(),
87 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
88 ContextServerState::Error { configuration, .. } => configuration.clone(),
89 }
90 }
91}
92
93#[derive(Debug, PartialEq, Eq)]
94pub enum ContextServerConfiguration {
95 Custom {
96 command: ContextServerCommand,
97 },
98 Extension {
99 command: ContextServerCommand,
100 settings: serde_json::Value,
101 },
102}
103
104impl ContextServerConfiguration {
105 pub fn command(&self) -> &ContextServerCommand {
106 match self {
107 ContextServerConfiguration::Custom { command } => command,
108 ContextServerConfiguration::Extension { command, .. } => command,
109 }
110 }
111
112 pub async fn from_settings(
113 settings: ContextServerSettings,
114 id: ContextServerId,
115 registry: Entity<ContextServerDescriptorRegistry>,
116 worktree_store: Entity<WorktreeStore>,
117 cx: &AsyncApp,
118 ) -> Option<Self> {
119 match settings {
120 ContextServerSettings::Custom {
121 enabled: _,
122 command,
123 } => Some(ContextServerConfiguration::Custom { command }),
124 ContextServerSettings::Extension {
125 enabled: _,
126 settings,
127 } => {
128 let descriptor = cx
129 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
130 .ok()
131 .flatten()?;
132
133 let command = descriptor.command(worktree_store, cx).await.log_err()?;
134
135 Some(ContextServerConfiguration::Extension { command, settings })
136 }
137 }
138 }
139}
140
141pub type ContextServerFactory =
142 Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
143
144pub struct ContextServerStore {
145 context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
146 servers: HashMap<ContextServerId, ContextServerState>,
147 worktree_store: Entity<WorktreeStore>,
148 project: WeakEntity<Project>,
149 registry: Entity<ContextServerDescriptorRegistry>,
150 update_servers_task: Option<Task<Result<()>>>,
151 context_server_factory: Option<ContextServerFactory>,
152 needs_server_update: bool,
153 _subscriptions: Vec<Subscription>,
154}
155
156pub enum Event {
157 ServerStatusChanged {
158 server_id: ContextServerId,
159 status: ContextServerStatus,
160 },
161}
162
163impl EventEmitter<Event> for ContextServerStore {}
164
165impl ContextServerStore {
166 pub fn new(
167 worktree_store: Entity<WorktreeStore>,
168 weak_project: WeakEntity<Project>,
169 cx: &mut Context<Self>,
170 ) -> Self {
171 Self::new_internal(
172 true,
173 None,
174 ContextServerDescriptorRegistry::default_global(cx),
175 worktree_store,
176 weak_project,
177 cx,
178 )
179 }
180
181 /// Returns all configured context server ids, regardless of enabled state.
182 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
183 self.context_server_settings
184 .keys()
185 .cloned()
186 .map(ContextServerId)
187 .collect()
188 }
189
190 #[cfg(any(test, feature = "test-support"))]
191 pub fn test(
192 registry: Entity<ContextServerDescriptorRegistry>,
193 worktree_store: Entity<WorktreeStore>,
194 weak_project: WeakEntity<Project>,
195 cx: &mut Context<Self>,
196 ) -> Self {
197 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
198 }
199
200 #[cfg(any(test, feature = "test-support"))]
201 pub fn test_maintain_server_loop(
202 context_server_factory: ContextServerFactory,
203 registry: Entity<ContextServerDescriptorRegistry>,
204 worktree_store: Entity<WorktreeStore>,
205 weak_project: WeakEntity<Project>,
206 cx: &mut Context<Self>,
207 ) -> Self {
208 Self::new_internal(
209 true,
210 Some(context_server_factory),
211 registry,
212 worktree_store,
213 weak_project,
214 cx,
215 )
216 }
217
218 fn new_internal(
219 maintain_server_loop: bool,
220 context_server_factory: Option<ContextServerFactory>,
221 registry: Entity<ContextServerDescriptorRegistry>,
222 worktree_store: Entity<WorktreeStore>,
223 weak_project: WeakEntity<Project>,
224 cx: &mut Context<Self>,
225 ) -> Self {
226 let subscriptions = if maintain_server_loop {
227 vec![
228 cx.observe(®istry, |this, _registry, cx| {
229 this.available_context_servers_changed(cx);
230 }),
231 cx.observe_global::<SettingsStore>(|this, cx| {
232 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
233 if &this.context_server_settings == settings {
234 return;
235 }
236 this.context_server_settings = settings.clone();
237 this.available_context_servers_changed(cx);
238 }),
239 ]
240 } else {
241 Vec::new()
242 };
243
244 let mut this = Self {
245 _subscriptions: subscriptions,
246 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
247 .clone(),
248 worktree_store,
249 project: weak_project,
250 registry,
251 needs_server_update: false,
252 servers: HashMap::default(),
253 update_servers_task: None,
254 context_server_factory,
255 };
256 if maintain_server_loop {
257 this.available_context_servers_changed(cx);
258 }
259 this
260 }
261
262 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
263 self.servers.get(id).map(|state| state.server())
264 }
265
266 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
267 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
268 Some(server.clone())
269 } else {
270 None
271 }
272 }
273
274 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
275 self.servers.get(id).map(ContextServerStatus::from_state)
276 }
277
278 pub fn configuration_for_server(
279 &self,
280 id: &ContextServerId,
281 ) -> Option<Arc<ContextServerConfiguration>> {
282 self.servers.get(id).map(|state| state.configuration())
283 }
284
285 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
286 self.servers.keys().cloned().collect()
287 }
288
289 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
290 self.servers
291 .values()
292 .filter_map(|state| {
293 if let ContextServerState::Running { server, .. } = state {
294 Some(server.clone())
295 } else {
296 None
297 }
298 })
299 .collect()
300 }
301
302 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
303 cx.spawn(async move |this, cx| {
304 let this = this.upgrade().context("Context server store dropped")?;
305 let settings = this
306 .update(cx, |this, _| {
307 this.context_server_settings.get(&server.id().0).cloned()
308 })
309 .ok()
310 .flatten()
311 .context("Failed to get context server settings")?;
312
313 if !settings.enabled() {
314 return Ok(());
315 }
316
317 let (registry, worktree_store) = this.update(cx, |this, _| {
318 (this.registry.clone(), this.worktree_store.clone())
319 })?;
320 let configuration = ContextServerConfiguration::from_settings(
321 settings,
322 server.id(),
323 registry,
324 worktree_store,
325 cx,
326 )
327 .await
328 .context("Failed to create context server configuration")?;
329
330 this.update(cx, |this, cx| {
331 this.run_server(server, Arc::new(configuration), cx)
332 })
333 })
334 .detach_and_log_err(cx);
335 }
336
337 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
338 if matches!(
339 self.servers.get(id),
340 Some(ContextServerState::Stopped { .. })
341 ) {
342 return Ok(());
343 }
344
345 let state = self
346 .servers
347 .remove(id)
348 .context("Context server not found")?;
349
350 let server = state.server();
351 let configuration = state.configuration();
352 let mut result = Ok(());
353 if let ContextServerState::Running { server, .. } = &state {
354 result = server.stop();
355 }
356 drop(state);
357
358 self.update_server_state(
359 id.clone(),
360 ContextServerState::Stopped {
361 configuration,
362 server,
363 },
364 cx,
365 );
366
367 result
368 }
369
370 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
371 if let Some(state) = self.servers.get(id) {
372 let configuration = state.configuration();
373
374 self.stop_server(&state.server().id(), cx)?;
375 let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
376 self.run_server(new_server, configuration, cx);
377 }
378 Ok(())
379 }
380
381 fn run_server(
382 &mut self,
383 server: Arc<ContextServer>,
384 configuration: Arc<ContextServerConfiguration>,
385 cx: &mut Context<Self>,
386 ) {
387 let id = server.id();
388 if matches!(
389 self.servers.get(&id),
390 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
391 ) {
392 self.stop_server(&id, cx).log_err();
393 }
394
395 let task = cx.spawn({
396 let id = server.id();
397 let server = server.clone();
398 let configuration = configuration.clone();
399 async move |this, cx| {
400 match server.clone().start(cx).await {
401 Ok(_) => {
402 debug_assert!(server.client().is_some());
403
404 this.update(cx, |this, cx| {
405 this.update_server_state(
406 id.clone(),
407 ContextServerState::Running {
408 server,
409 configuration,
410 },
411 cx,
412 )
413 })
414 .log_err()
415 }
416 Err(err) => {
417 log::error!("{} context server failed to start: {}", id, err);
418 this.update(cx, |this, cx| {
419 this.update_server_state(
420 id.clone(),
421 ContextServerState::Error {
422 configuration,
423 server,
424 error: err.to_string().into(),
425 },
426 cx,
427 )
428 })
429 .log_err()
430 }
431 };
432 }
433 });
434
435 self.update_server_state(
436 id.clone(),
437 ContextServerState::Starting {
438 configuration,
439 _task: task,
440 server,
441 },
442 cx,
443 );
444 }
445
446 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
447 let state = self
448 .servers
449 .remove(id)
450 .context("Context server not found")?;
451 drop(state);
452 cx.emit(Event::ServerStatusChanged {
453 server_id: id.clone(),
454 status: ContextServerStatus::Stopped,
455 });
456 Ok(())
457 }
458
459 fn create_context_server(
460 &self,
461 id: ContextServerId,
462 configuration: Arc<ContextServerConfiguration>,
463 cx: &mut Context<Self>,
464 ) -> Arc<ContextServer> {
465 let root_path = self
466 .project
467 .read_with(cx, |project, cx| project.active_project_directory(cx))
468 .ok()
469 .flatten()
470 .or_else(|| {
471 self.worktree_store.read_with(cx, |store, cx| {
472 store.visible_worktrees(cx).fold(None, |acc, item| {
473 if acc.is_none() {
474 item.read(cx).root_dir()
475 } else {
476 acc
477 }
478 })
479 })
480 });
481
482 if let Some(factory) = self.context_server_factory.as_ref() {
483 factory(id, configuration)
484 } else {
485 Arc::new(ContextServer::stdio(
486 id,
487 configuration.command().clone(),
488 root_path,
489 ))
490 }
491 }
492
493 fn resolve_context_server_settings<'a>(
494 worktree_store: &'a Entity<WorktreeStore>,
495 cx: &'a App,
496 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
497 let location = worktree_store
498 .read(cx)
499 .visible_worktrees(cx)
500 .next()
501 .map(|worktree| settings::SettingsLocation {
502 worktree_id: worktree.read(cx).id(),
503 path: Path::new(""),
504 });
505 &ProjectSettings::get(location, cx).context_servers
506 }
507
508 fn update_server_state(
509 &mut self,
510 id: ContextServerId,
511 state: ContextServerState,
512 cx: &mut Context<Self>,
513 ) {
514 let status = ContextServerStatus::from_state(&state);
515 self.servers.insert(id.clone(), state);
516 cx.emit(Event::ServerStatusChanged {
517 server_id: id,
518 status,
519 });
520 }
521
522 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
523 if self.update_servers_task.is_some() {
524 self.needs_server_update = true;
525 } else {
526 self.needs_server_update = false;
527 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
528 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
529 log::error!("Error maintaining context servers: {}", err);
530 }
531
532 this.update(cx, |this, cx| {
533 this.update_servers_task.take();
534 if this.needs_server_update {
535 this.available_context_servers_changed(cx);
536 }
537 })?;
538
539 Ok(())
540 }));
541 }
542 }
543
544 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
545 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
546 (
547 this.context_server_settings.clone(),
548 this.registry.clone(),
549 this.worktree_store.clone(),
550 )
551 })?;
552
553 for (id, _) in
554 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
555 {
556 configured_servers
557 .entry(id)
558 .or_insert(ContextServerSettings::default_extension());
559 }
560
561 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
562 configured_servers
563 .into_iter()
564 .partition(|(_, settings)| settings.enabled());
565
566 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
567 let id = ContextServerId(id);
568 ContextServerConfiguration::from_settings(
569 settings,
570 id.clone(),
571 registry.clone(),
572 worktree_store.clone(),
573 cx,
574 )
575 .map(|config| (id, config))
576 }))
577 .await
578 .into_iter()
579 .filter_map(|(id, config)| config.map(|config| (id, config)))
580 .collect::<HashMap<_, _>>();
581
582 let mut servers_to_start = Vec::new();
583 let mut servers_to_remove = HashSet::default();
584 let mut servers_to_stop = HashSet::default();
585
586 this.update(cx, |this, cx| {
587 for server_id in this.servers.keys() {
588 // All servers that are not in desired_servers should be removed from the store.
589 // This can happen if the user removed a server from the context server settings.
590 if !configured_servers.contains_key(server_id) {
591 if disabled_servers.contains_key(&server_id.0) {
592 servers_to_stop.insert(server_id.clone());
593 } else {
594 servers_to_remove.insert(server_id.clone());
595 }
596 }
597 }
598
599 for (id, config) in configured_servers {
600 let state = this.servers.get(&id);
601 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
602 let existing_config = state.as_ref().map(|state| state.configuration());
603 if existing_config.as_deref() != Some(&config) || is_stopped {
604 let config = Arc::new(config);
605 let server = this.create_context_server(id.clone(), config.clone(), cx);
606 servers_to_start.push((server, config));
607 if this.servers.contains_key(&id) {
608 servers_to_stop.insert(id);
609 }
610 }
611 }
612 })?;
613
614 this.update(cx, |this, cx| {
615 for id in servers_to_stop {
616 this.stop_server(&id, cx)?;
617 }
618 for id in servers_to_remove {
619 this.remove_server(&id, cx)?;
620 }
621 for (server, config) in servers_to_start {
622 this.run_server(server, config, cx);
623 }
624 anyhow::Ok(())
625 })?
626 }
627}
628
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
        project_settings::ProjectSettings,
    };
    use context_server::test::create_fake_transport;
    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
    use serde_json::json;
    use std::{cell::RefCell, path::PathBuf, rc::Rc};
    use util::path;

    /// Starting and stopping servers by hand is reflected by `status_for_server`.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;
        let store = new_passive_store(&project, cx);

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());
        let server_1 = fake_server(&server_1_id, cx);
        let server_2 = fake_server(&server_2_id, cx);

        store.update(cx, |store, cx| store.start_server(server_1, cx));
        cx.run_until_parked();
        assert_statuses(
            &store,
            &[
                (&server_1_id, Some(ContextServerStatus::Running)),
                (&server_2_id, None),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
        cx.run_until_parked();
        assert_statuses(
            &store,
            &[
                (&server_1_id, Some(ContextServerStatus::Running)),
                (&server_2_id, Some(ContextServerStatus::Running)),
            ],
            cx,
        );

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
        assert_statuses(
            &store,
            &[
                (&server_1_id, Some(ContextServerStatus::Running)),
                (&server_2_id, Some(ContextServerStatus::Stopped)),
            ],
            cx,
        );
    }

    /// Starting and stopping servers by hand emits the expected event sequence.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;
        let store = new_passive_store(&project, cx);

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());
        let server_1 = fake_server(&server_1_id, cx);
        let server_2 = fake_server(&server_2_id, cx);

        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));
        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }

    /// Two back-to-back starts for the same id end with exactly one running server.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;
        let store = new_passive_store(&project, cx);

        let server_id = ContextServerId(SERVER_1_ID.into());
        let server_with_same_id_1 = fake_server(&server_id, cx);
        let server_with_same_id_2 = fake_server(&server_id, cx);

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            events_for(
                &server_id,
                [
                    ContextServerStatus::Starting,
                    ContextServerStatus::Stopped,
                    ContextServerStatus::Starting,
                    ContextServerStatus::Running,
                ],
            ),
            cx,
        );

        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        assert_statuses(
            &store,
            &[(&server_id, Some(ContextServerStatus::Running))],
            cx,
        );
    }

    /// The maintenance loop starts, restarts and removes servers as the
    /// settings change, and stays quiet when they do not.
    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), extension_settings(true))],
        )
        .await;

        let registry = cx.new(|cx| {
            let mut registry = ContextServerDescriptorRegistry::new();
            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
            registry
        });
        let store = new_maintained_store(registry.clone(), &project, cx);

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                events_for(
                    &server_1_id,
                    [ContextServerStatus::Starting, ContextServerStatus::Running],
                ),
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                events_for(
                    &server_1_id,
                    [
                        ContextServerStatus::Stopped,
                        ContextServerStatus::Starting,
                        ContextServerStatus::Running,
                    ],
                ),
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                events_for(
                    &server_2_id,
                    [ContextServerStatus::Starting, ContextServerStatus::Running],
                ),
                cx,
            );
            set_context_server_configuration(
                vec![
                    (server_1_id.0.clone(), extension_settings(false)),
                    (server_2_id.0.clone(), custom_settings(true, "arg")),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is restarted once the args have changed
        {
            let _server_events = assert_server_events(
                &store,
                events_for(
                    &server_2_id,
                    [
                        ContextServerStatus::Stopped,
                        ContextServerStatus::Starting,
                        ContextServerStatus::Running,
                    ],
                ),
                cx,
            );
            set_context_server_configuration(
                vec![
                    (server_1_id.0.clone(), extension_settings(false)),
                    (server_2_id.0.clone(), custom_settings(true, "anotherArg")),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                events_for(&server_2_id, [ContextServerStatus::Stopped]),
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );

            cx.run_until_parked();

            assert_statuses(&store, &[(&server_2_id, None)], cx);
        }

        // Ensure that nothing happens if the settings do not change
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), extension_settings(false))],
                cx,
            );

            cx.run_until_parked();

            assert_statuses(
                &store,
                &[
                    (&server_1_id, Some(ContextServerStatus::Running)),
                    (&server_2_id, None),
                ],
                cx,
            );
        }
    }

    /// Toggling `enabled` in the settings stops and restarts the server.
    #[gpui::test]
    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let server_1_id = ContextServerId(SERVER_1_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), custom_settings(true, "arg"))],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = new_maintained_store(registry.clone(), &project, cx);

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                events_for(
                    &server_1_id,
                    [ContextServerStatus::Starting, ContextServerStatus::Running],
                ),
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is stopped once it is disabled.
        {
            let _server_events = assert_server_events(
                &store,
                events_for(&server_1_id, [ContextServerStatus::Stopped]),
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), custom_settings(false, "arg"))],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is started once it is enabled again.
        {
            let _server_events = assert_server_events(
                &store,
                events_for(
                    &server_1_id,
                    [ContextServerStatus::Starting, ContextServerStatus::Running],
                ),
                cx,
            );
            set_context_server_configuration(
                vec![(server_1_id.0.clone(), custom_settings(true, "arg"))],
                cx,
            );

            cx.run_until_parked();
        }
    }

    /// Replaces the user settings with default project settings plus the
    /// given context servers.
    fn set_context_server_configuration(
        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                let mut settings = ProjectSettings::default();
                settings.context_servers.extend(context_servers);
                let serialized = serde_json::to_string(&settings).unwrap();
                store.set_user_settings(&serialized, cx).unwrap();
            })
        });
    }

    /// Guard returned by `assert_server_events`; on drop it verifies that
    /// exactly the expected number of events was received.
    struct ServerEvents {
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }

    impl Drop for ServerEvents {
        fn drop(&mut self) {
            let actual_event_count = *self.received_event_count.borrow();
            assert_eq!(
                actual_event_count, self.expected_event_count,
                "
 Expected to receive {} context server store events, but received {} events",
                self.expected_event_count, actual_event_count
            );
        }
    }

    /// Settings for an enabled custom server launched as `somebinary arg`.
    fn dummy_server_settings() -> ContextServerSettings {
        custom_settings(true, "arg")
    }

    /// Custom server settings for `somebinary` with a single argument.
    fn custom_settings(enabled: bool, arg: &str) -> ContextServerSettings {
        ContextServerSettings::Custom {
            enabled,
            command: ContextServerCommand {
                path: "somebinary".into(),
                args: vec![arg.to_string()],
                env: None,
                timeout: None,
            },
        }
    }

    /// Enabled extension settings whose payload is `{"somevalue": <somevalue>}`.
    fn extension_settings(somevalue: bool) -> ContextServerSettings {
        ContextServerSettings::Extension {
            enabled: true,
            settings: json!({ "somevalue": somevalue }),
        }
    }

    /// Pairs every status with `id`, preserving order.
    fn events_for(
        id: &ContextServerId,
        statuses: impl IntoIterator<Item = ContextServerStatus>,
    ) -> Vec<(ContextServerId, ContextServerStatus)> {
        statuses
            .into_iter()
            .map(|status| (id.clone(), status))
            .collect()
    }

    /// Builds a server for `id` on top of a fake transport named after the id.
    fn fake_server(id: &ContextServerId, cx: &mut TestAppContext) -> Arc<ContextServer> {
        Arc::new(ContextServer::new(
            id.clone(),
            Arc::new(create_fake_transport(id.0.to_string(), cx.executor())),
        ))
    }

    /// Creates a store without the maintenance loop, backed by an empty registry.
    fn new_passive_store(
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> Entity<ContextServerStore> {
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        })
    }

    /// Creates a store with the maintenance loop enabled whose servers run on
    /// fake transports.
    fn new_maintained_store(
        registry: Entity<ContextServerDescriptorRegistry>,
        project: &Entity<Project>,
        cx: &mut TestAppContext,
    ) -> Entity<ContextServerStore> {
        let executor = cx.executor();
        cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        })
    }

    /// Asserts the status the store reports for each listed server id.
    fn assert_statuses(
        store: &Entity<ContextServerStore>,
        expected: &[(&ContextServerId, Option<ContextServerStatus>)],
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            for (id, status) in expected {
                assert_eq!(&store.read(cx).status_for_server(id), status);
            }
        });
    }

    /// Subscribes to the store and checks every incoming event against
    /// `expected_events`, in order. The returned guard checks the total count
    /// when it is dropped.
    fn assert_server_events(
        store: &Entity<ContextServerStore>,
        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
        cx: &mut TestAppContext,
    ) -> ServerEvents {
        cx.update(|cx| {
            let expected_event_count = expected_events.len();
            let received_event_count = Rc::new(RefCell::new(0));
            let counter = received_event_count.clone();
            let mut ix = 0;

            let subscription = cx.subscribe(store, move |_, event, _| {
                let Event::ServerStatusChanged {
                    server_id: actual_server_id,
                    status: actual_status,
                } = event;
                let (expected_server_id, expected_status) = &expected_events[ix];

                assert_eq!(
                    actual_server_id, expected_server_id,
                    "Expected different server id at index {}",
                    ix
                );
                assert_eq!(
                    actual_status, expected_status,
                    "Expected different status at index {}",
                    ix
                );
                ix += 1;
                *counter.borrow_mut() += 1;
            });

            ServerEvents {
                expected_event_count,
                received_event_count,
                _subscription: subscription,
            }
        })
    }

    /// Installs a test settings store containing the given context servers
    /// and creates a test project rooted at `/test` with `files`.
    async fn setup_context_server_test(
        cx: &mut TestAppContext,
        files: serde_json::Value,
        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
    ) -> (Arc<FakeFs>, Entity<Project>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);

            let mut settings = ProjectSettings::get_global(cx).clone();
            settings
                .context_servers
                .extend(context_server_configurations);
            ProjectSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        (fs, project)
    }

    /// Descriptor that always yields the command `<path> arg1 arg2` and no
    /// extension configuration.
    struct FakeContextServerDescriptor {
        path: PathBuf,
    }

    impl FakeContextServerDescriptor {
        fn new(path: impl Into<PathBuf>) -> Self {
            let path = path.into();
            Self { path }
        }
    }

    impl ContextServerDescriptor for FakeContextServerDescriptor {
        fn command(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<ContextServerCommand>> {
            let command = ContextServerCommand {
                path: self.path.clone(),
                args: vec!["arg1".to_string(), "arg2".to_string()],
                env: None,
                timeout: None,
            };
            Task::ready(Ok(command))
        }

        fn configuration(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
            Task::ready(Ok(None))
        }
    }
}