1pub mod extension;
2pub mod registry;
3
4use std::sync::Arc;
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::{ResultExt as _, rel_path::RelPath};
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 worktree_store::WorktreeStore,
19};
20
21pub fn init(cx: &mut App) {
22 extension::init(cx);
23}
24
// Declares the `Restart` action in the `context_server` namespace. The `///` line
// inside the macro is part of the macro input and is intentionally left as-is.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub enum ContextServerStatus {
35 Starting,
36 Running,
37 Stopped,
38 Error(Arc<str>),
39}
40
41impl ContextServerStatus {
42 fn from_state(state: &ContextServerState) -> Self {
43 match state {
44 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
45 ContextServerState::Running { .. } => ContextServerStatus::Running,
46 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
47 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
48 }
49 }
50}
51
52enum ContextServerState {
53 Starting {
54 server: Arc<ContextServer>,
55 configuration: Arc<ContextServerConfiguration>,
56 _task: Task<()>,
57 },
58 Running {
59 server: Arc<ContextServer>,
60 configuration: Arc<ContextServerConfiguration>,
61 },
62 Stopped {
63 server: Arc<ContextServer>,
64 configuration: Arc<ContextServerConfiguration>,
65 },
66 Error {
67 server: Arc<ContextServer>,
68 configuration: Arc<ContextServerConfiguration>,
69 error: Arc<str>,
70 },
71}
72
73impl ContextServerState {
74 pub fn server(&self) -> Arc<ContextServer> {
75 match self {
76 ContextServerState::Starting { server, .. } => server.clone(),
77 ContextServerState::Running { server, .. } => server.clone(),
78 ContextServerState::Stopped { server, .. } => server.clone(),
79 ContextServerState::Error { server, .. } => server.clone(),
80 }
81 }
82
83 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
84 match self {
85 ContextServerState::Starting { configuration, .. } => configuration.clone(),
86 ContextServerState::Running { configuration, .. } => configuration.clone(),
87 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
88 ContextServerState::Error { configuration, .. } => configuration.clone(),
89 }
90 }
91}
92
93#[derive(Debug, PartialEq, Eq)]
94pub enum ContextServerConfiguration {
95 Custom {
96 command: ContextServerCommand,
97 },
98 Extension {
99 command: ContextServerCommand,
100 settings: serde_json::Value,
101 },
102}
103
104impl ContextServerConfiguration {
105 pub fn command(&self) -> &ContextServerCommand {
106 match self {
107 ContextServerConfiguration::Custom { command } => command,
108 ContextServerConfiguration::Extension { command, .. } => command,
109 }
110 }
111
112 pub async fn from_settings(
113 settings: ContextServerSettings,
114 id: ContextServerId,
115 registry: Entity<ContextServerDescriptorRegistry>,
116 worktree_store: Entity<WorktreeStore>,
117 cx: &AsyncApp,
118 ) -> Option<Self> {
119 match settings {
120 ContextServerSettings::Custom {
121 enabled: _,
122 command,
123 } => Some(ContextServerConfiguration::Custom { command }),
124 ContextServerSettings::Extension {
125 enabled: _,
126 settings,
127 } => {
128 let descriptor = cx
129 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
130 .ok()
131 .flatten()?;
132
133 match descriptor.command(worktree_store, cx).await {
134 Ok(command) => {
135 Some(ContextServerConfiguration::Extension { command, settings })
136 }
137 Err(e) => {
138 log::error!(
139 "Failed to create context server configuration from settings: {e:#}"
140 );
141 None
142 }
143 }
144 }
145 }
146 }
147}
148
/// Constructor that builds a [`ContextServer`] for a server id and its resolved
/// configuration. When a store is given one (via `test_maintain_server_loop`), it is
/// used in `create_context_server` instead of the default stdio construction.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
151
/// Tracks the context servers of a project and drives their lifecycle (start, stop,
/// restart, removal), emitting an [`Event`] whenever a server's status changes.
pub struct ContextServerStore {
    /// Context server settings resolved from the project settings, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    project: WeakEntity<Project>,
    /// Registry of descriptors for extension-provided context servers.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The reconciliation task currently running, if any.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for server construction; only set by `test_maintain_server_loop`.
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a reconciliation is requested while one is already running, so that
    /// another pass is scheduled after it finishes.
    needs_server_update: bool,
    _subscriptions: Vec<Subscription>,
}
163
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// The status of `server_id` changed. Removing a server from the store is also
    /// reported with [`ContextServerStatus::Stopped`].
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
172
173impl ContextServerStore {
174 pub fn new(
175 worktree_store: Entity<WorktreeStore>,
176 weak_project: WeakEntity<Project>,
177 cx: &mut Context<Self>,
178 ) -> Self {
179 Self::new_internal(
180 true,
181 None,
182 ContextServerDescriptorRegistry::default_global(cx),
183 worktree_store,
184 weak_project,
185 cx,
186 )
187 }
188
189 /// Returns all configured context server ids, regardless of enabled state.
190 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
191 self.context_server_settings
192 .keys()
193 .cloned()
194 .map(ContextServerId)
195 .collect()
196 }
197
198 #[cfg(any(test, feature = "test-support"))]
199 pub fn test(
200 registry: Entity<ContextServerDescriptorRegistry>,
201 worktree_store: Entity<WorktreeStore>,
202 weak_project: WeakEntity<Project>,
203 cx: &mut Context<Self>,
204 ) -> Self {
205 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
206 }
207
208 #[cfg(any(test, feature = "test-support"))]
209 pub fn test_maintain_server_loop(
210 context_server_factory: ContextServerFactory,
211 registry: Entity<ContextServerDescriptorRegistry>,
212 worktree_store: Entity<WorktreeStore>,
213 weak_project: WeakEntity<Project>,
214 cx: &mut Context<Self>,
215 ) -> Self {
216 Self::new_internal(
217 true,
218 Some(context_server_factory),
219 registry,
220 worktree_store,
221 weak_project,
222 cx,
223 )
224 }
225
226 fn new_internal(
227 maintain_server_loop: bool,
228 context_server_factory: Option<ContextServerFactory>,
229 registry: Entity<ContextServerDescriptorRegistry>,
230 worktree_store: Entity<WorktreeStore>,
231 weak_project: WeakEntity<Project>,
232 cx: &mut Context<Self>,
233 ) -> Self {
234 let subscriptions = if maintain_server_loop {
235 vec![
236 cx.observe(®istry, |this, _registry, cx| {
237 this.available_context_servers_changed(cx);
238 }),
239 cx.observe_global::<SettingsStore>(|this, cx| {
240 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
241 if &this.context_server_settings == settings {
242 return;
243 }
244 this.context_server_settings = settings.clone();
245 this.available_context_servers_changed(cx);
246 }),
247 ]
248 } else {
249 Vec::new()
250 };
251
252 let mut this = Self {
253 _subscriptions: subscriptions,
254 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
255 .clone(),
256 worktree_store,
257 project: weak_project,
258 registry,
259 needs_server_update: false,
260 servers: HashMap::default(),
261 update_servers_task: None,
262 context_server_factory,
263 };
264 if maintain_server_loop {
265 this.available_context_servers_changed(cx);
266 }
267 this
268 }
269
270 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
271 self.servers.get(id).map(|state| state.server())
272 }
273
274 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
275 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
276 Some(server.clone())
277 } else {
278 None
279 }
280 }
281
282 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
283 self.servers.get(id).map(ContextServerStatus::from_state)
284 }
285
286 pub fn configuration_for_server(
287 &self,
288 id: &ContextServerId,
289 ) -> Option<Arc<ContextServerConfiguration>> {
290 self.servers.get(id).map(|state| state.configuration())
291 }
292
293 pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
294 self.servers
295 .keys()
296 .cloned()
297 .chain(
298 self.registry
299 .read(cx)
300 .context_server_descriptors()
301 .into_iter()
302 .map(|(id, _)| ContextServerId(id)),
303 )
304 .collect()
305 }
306
307 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
308 self.servers
309 .values()
310 .filter_map(|state| {
311 if let ContextServerState::Running { server, .. } = state {
312 Some(server.clone())
313 } else {
314 None
315 }
316 })
317 .collect()
318 }
319
320 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
321 cx.spawn(async move |this, cx| {
322 let this = this.upgrade().context("Context server store dropped")?;
323 let settings = this
324 .update(cx, |this, _| {
325 this.context_server_settings.get(&server.id().0).cloned()
326 })
327 .ok()
328 .flatten()
329 .context("Failed to get context server settings")?;
330
331 if !settings.enabled() {
332 return Ok(());
333 }
334
335 let (registry, worktree_store) = this.update(cx, |this, _| {
336 (this.registry.clone(), this.worktree_store.clone())
337 })?;
338 let configuration = ContextServerConfiguration::from_settings(
339 settings,
340 server.id(),
341 registry,
342 worktree_store,
343 cx,
344 )
345 .await
346 .context("Failed to create context server configuration")?;
347
348 this.update(cx, |this, cx| {
349 this.run_server(server, Arc::new(configuration), cx)
350 })
351 })
352 .detach_and_log_err(cx);
353 }
354
355 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
356 if matches!(
357 self.servers.get(id),
358 Some(ContextServerState::Stopped { .. })
359 ) {
360 return Ok(());
361 }
362
363 let state = self
364 .servers
365 .remove(id)
366 .context("Context server not found")?;
367
368 let server = state.server();
369 let configuration = state.configuration();
370 let mut result = Ok(());
371 if let ContextServerState::Running { server, .. } = &state {
372 result = server.stop();
373 }
374 drop(state);
375
376 self.update_server_state(
377 id.clone(),
378 ContextServerState::Stopped {
379 configuration,
380 server,
381 },
382 cx,
383 );
384
385 result
386 }
387
388 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
389 if let Some(state) = self.servers.get(id) {
390 let configuration = state.configuration();
391
392 self.stop_server(&state.server().id(), cx)?;
393 let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
394 self.run_server(new_server, configuration, cx);
395 }
396 Ok(())
397 }
398
399 fn run_server(
400 &mut self,
401 server: Arc<ContextServer>,
402 configuration: Arc<ContextServerConfiguration>,
403 cx: &mut Context<Self>,
404 ) {
405 let id = server.id();
406 if matches!(
407 self.servers.get(&id),
408 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
409 ) {
410 self.stop_server(&id, cx).log_err();
411 }
412
413 let task = cx.spawn({
414 let id = server.id();
415 let server = server.clone();
416 let configuration = configuration.clone();
417 async move |this, cx| {
418 match server.clone().start(cx).await {
419 Ok(_) => {
420 debug_assert!(server.client().is_some());
421
422 this.update(cx, |this, cx| {
423 this.update_server_state(
424 id.clone(),
425 ContextServerState::Running {
426 server,
427 configuration,
428 },
429 cx,
430 )
431 })
432 .log_err()
433 }
434 Err(err) => {
435 log::error!("{} context server failed to start: {}", id, err);
436 this.update(cx, |this, cx| {
437 this.update_server_state(
438 id.clone(),
439 ContextServerState::Error {
440 configuration,
441 server,
442 error: err.to_string().into(),
443 },
444 cx,
445 )
446 })
447 .log_err()
448 }
449 };
450 }
451 });
452
453 self.update_server_state(
454 id.clone(),
455 ContextServerState::Starting {
456 configuration,
457 _task: task,
458 server,
459 },
460 cx,
461 );
462 }
463
464 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
465 let state = self
466 .servers
467 .remove(id)
468 .context("Context server not found")?;
469 drop(state);
470 cx.emit(Event::ServerStatusChanged {
471 server_id: id.clone(),
472 status: ContextServerStatus::Stopped,
473 });
474 Ok(())
475 }
476
477 fn create_context_server(
478 &self,
479 id: ContextServerId,
480 configuration: Arc<ContextServerConfiguration>,
481 cx: &mut Context<Self>,
482 ) -> Arc<ContextServer> {
483 let project = self.project.upgrade();
484 let mut root_path = None;
485 if let Some(project) = project {
486 let project = project.read(cx);
487 if project.is_local() {
488 if let Some(path) = project.active_project_directory(cx) {
489 root_path = Some(path);
490 } else {
491 for worktree in self.worktree_store.read(cx).visible_worktrees(cx) {
492 if let Some(path) = worktree.read(cx).root_dir() {
493 root_path = Some(path);
494 break;
495 }
496 }
497 }
498 }
499 };
500
501 if let Some(factory) = self.context_server_factory.as_ref() {
502 factory(id, configuration)
503 } else {
504 Arc::new(ContextServer::stdio(
505 id,
506 configuration.command().clone(),
507 root_path,
508 ))
509 }
510 }
511
512 fn resolve_context_server_settings<'a>(
513 worktree_store: &'a Entity<WorktreeStore>,
514 cx: &'a App,
515 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
516 let location = worktree_store
517 .read(cx)
518 .visible_worktrees(cx)
519 .next()
520 .map(|worktree| settings::SettingsLocation {
521 worktree_id: worktree.read(cx).id(),
522 path: RelPath::empty(),
523 });
524 &ProjectSettings::get(location, cx).context_servers
525 }
526
527 fn update_server_state(
528 &mut self,
529 id: ContextServerId,
530 state: ContextServerState,
531 cx: &mut Context<Self>,
532 ) {
533 let status = ContextServerStatus::from_state(&state);
534 self.servers.insert(id.clone(), state);
535 cx.emit(Event::ServerStatusChanged {
536 server_id: id,
537 status,
538 });
539 }
540
541 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
542 if self.update_servers_task.is_some() {
543 self.needs_server_update = true;
544 } else {
545 self.needs_server_update = false;
546 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
547 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
548 log::error!("Error maintaining context servers: {}", err);
549 }
550
551 this.update(cx, |this, cx| {
552 this.update_servers_task.take();
553 if this.needs_server_update {
554 this.available_context_servers_changed(cx);
555 }
556 })?;
557
558 Ok(())
559 }));
560 }
561 }
562
563 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
564 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
565 (
566 this.context_server_settings.clone(),
567 this.registry.clone(),
568 this.worktree_store.clone(),
569 )
570 })?;
571
572 for (id, _) in
573 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
574 {
575 configured_servers
576 .entry(id)
577 .or_insert(ContextServerSettings::default_extension());
578 }
579
580 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
581 configured_servers
582 .into_iter()
583 .partition(|(_, settings)| settings.enabled());
584
585 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
586 let id = ContextServerId(id);
587 ContextServerConfiguration::from_settings(
588 settings,
589 id.clone(),
590 registry.clone(),
591 worktree_store.clone(),
592 cx,
593 )
594 .map(|config| (id, config))
595 }))
596 .await
597 .into_iter()
598 .filter_map(|(id, config)| config.map(|config| (id, config)))
599 .collect::<HashMap<_, _>>();
600
601 let mut servers_to_start = Vec::new();
602 let mut servers_to_remove = HashSet::default();
603 let mut servers_to_stop = HashSet::default();
604
605 this.update(cx, |this, cx| {
606 for server_id in this.servers.keys() {
607 // All servers that are not in desired_servers should be removed from the store.
608 // This can happen if the user removed a server from the context server settings.
609 if !configured_servers.contains_key(server_id) {
610 if disabled_servers.contains_key(&server_id.0) {
611 servers_to_stop.insert(server_id.clone());
612 } else {
613 servers_to_remove.insert(server_id.clone());
614 }
615 }
616 }
617
618 for (id, config) in configured_servers {
619 let state = this.servers.get(&id);
620 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
621 let existing_config = state.as_ref().map(|state| state.configuration());
622 if existing_config.as_deref() != Some(&config) || is_stopped {
623 let config = Arc::new(config);
624 let server = this.create_context_server(id.clone(), config.clone(), cx);
625 servers_to_start.push((server, config));
626 if this.servers.contains_key(&id) {
627 servers_to_stop.insert(id);
628 }
629 }
630 }
631 })?;
632
633 this.update(cx, |this, cx| {
634 for id in servers_to_stop {
635 this.stop_server(&id, cx)?;
636 }
637 for id in servers_to_remove {
638 this.remove_server(&id, cx)?;
639 }
640 for (server, config) in servers_to_start {
641 this.run_server(server, config, cx);
642 }
643 anyhow::Ok(())
644 })?
645 }
646}
647
// Tests for `ContextServerStore`: explicit start/stop status transitions, the exact
// sequence of emitted status events, and the settings-driven reconciliation loop.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
        project_settings::ProjectSettings,
    };
    use context_server::test::create_fake_transport;
    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
    use serde_json::json;
    use std::{cell::RefCell, path::PathBuf, rc::Rc};
    use util::path;

    // Explicitly started servers reach `Running`; stopping one leaves the other running.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }

    // Checks the exact, ordered sequence of status events for start, start, stop.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // The guard checks the total event count when it goes out of scope at the end
        // of the test.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }

    // Two starts with the same id in quick succession must end with exactly one running
    // server. Run for several iterations to cover different interleavings.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }

    // Drives the reconciliation loop through settings changes. It checks the initial
    // start, a restart on a config change, no-ops for unchanged settings, adding a
    // server, a restart on changed args, and removal.
    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Extension {
                    enabled: true,
                    settings: json!({
                        "somevalue": true
                    }),
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|cx| {
            let mut registry = ContextServerDescriptorRegistry::new();
            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
            registry
        });
        // The factory builds servers over fake transports instead of spawning processes.
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Stopped),
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        settings::ContextServerSettingsContent::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        settings::ContextServerSettingsContent::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".into(),
                                args: vec!["arg".to_string()],
                                env: None,
                                timeout: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is restarted once the args have changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Stopped),
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        settings::ContextServerSettingsContent::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        settings::ContextServerSettingsContent::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".into(),
                                args: vec!["anotherArg".to_string()],
                                env: None,
                                timeout: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }

        // Ensure that nothing happens if the settings do not change
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(
                    store.read(cx).status_for_server(&server_1_id),
                    Some(ContextServerStatus::Running)
                );
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }
    }

    // Toggling `enabled` in the settings stops the server and later starts it again.
    #[gpui::test]
    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let server_1_id = ContextServerId(SERVER_1_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Custom {
                    enabled: true,
                    command: ContextServerCommand {
                        path: "somebinary".into(),
                        args: vec!["arg".to_string()],
                        env: None,
                        timeout: None,
                    },
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is stopped once it is disabled.
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Custom {
                        enabled: false,
                        command: ContextServerCommand {
                            path: "somebinary".into(),
                            args: vec!["arg".to_string()],
                            env: None,
                            timeout: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is started once it is enabled again.
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Custom {
                        enabled: true,
                        command: ContextServerCommand {
                            path: "somebinary".into(),
                            args: vec!["arg".to_string()],
                            timeout: None,
                            env: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }
    }

    /// Replaces the user's context server settings with exactly `context_servers`.
    fn set_context_server_configuration(
        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |content| {
                    // Clear first so servers missing from `context_servers` are removed.
                    content.project.context_servers.clear();
                    for (id, config) in context_servers {
                        content.project.context_servers.insert(id, config);
                    }
                });
            })
        });
    }

    /// Guard returned by `assert_server_events`. It holds the event subscription and,
    /// when dropped, asserts that exactly the expected number of events arrived.
    struct ServerEvents {
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }

    impl Drop for ServerEvents {
        fn drop(&mut self) {
            let actual_event_count = *self.received_event_count.borrow();
            assert_eq!(
                actual_event_count, self.expected_event_count,
                "
               Expected to receive {} context server store events, but received {} events",
                self.expected_event_count, actual_event_count
            );
        }
    }

    /// Custom settings for a placeholder `somebinary arg` server.
    fn dummy_server_settings() -> ContextServerSettings {
        ContextServerSettings::Custom {
            enabled: true,
            command: ContextServerCommand {
                path: "somebinary".into(),
                args: vec!["arg".to_string()],
                env: None,
                timeout: None,
            },
        }
    }

    /// Subscribes to `store` and asserts that each emitted event matches the next entry
    /// of `expected_events`, in order. Keep the returned guard alive for as long as
    /// events should be checked. Dropping it checks the total count (and presumably
    /// ends the subscription it holds).
    fn assert_server_events(
        store: &Entity<ContextServerStore>,
        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
        cx: &mut TestAppContext,
    ) -> ServerEvents {
        cx.update(|cx| {
            let mut ix = 0;
            let received_event_count = Rc::new(RefCell::new(0));
            let expected_event_count = expected_events.len();
            let subscription = cx.subscribe(store, {
                let received_event_count = received_event_count.clone();
                move |_, event, _| match event {
                    Event::ServerStatusChanged {
                        server_id: actual_server_id,
                        status: actual_status,
                    } => {
                        // Indexing panics if more events arrive than were expected.
                        let (expected_server_id, expected_status) = &expected_events[ix];

                        assert_eq!(
                            actual_server_id, expected_server_id,
                            "Expected different server id at index {}",
                            ix
                        );
                        assert_eq!(
                            actual_status, expected_status,
                            "Expected different status at index {}",
                            ix
                        );
                        ix += 1;
                        *received_event_count.borrow_mut() += 1;
                    }
                }
            });
            ServerEvents {
                expected_event_count,
                received_event_count,
                _subscription: subscription,
            }
        })
    }

    /// Installs a test `SettingsStore` whose project settings include
    /// `context_server_configurations`. It then creates a project over a `FakeFs` with
    /// `files` placed under `/test`.
    async fn setup_context_server_test(
        cx: &mut TestAppContext,
        files: serde_json::Value,
        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
    ) -> (Arc<FakeFs>, Entity<Project>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            let mut settings = ProjectSettings::get_global(cx).clone();
            for (id, config) in context_server_configurations {
                settings.context_servers.insert(id, config);
            }
            ProjectSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        (fs, project)
    }

    /// Descriptor that always returns a command with the given `path` and fixed args,
    /// and reports no extension configuration.
    struct FakeContextServerDescriptor {
        path: PathBuf,
    }

    impl FakeContextServerDescriptor {
        fn new(path: impl Into<PathBuf>) -> Self {
            Self { path: path.into() }
        }
    }

    impl ContextServerDescriptor for FakeContextServerDescriptor {
        fn command(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<ContextServerCommand>> {
            Task::ready(Ok(ContextServerCommand {
                path: self.path.clone(),
                args: vec!["arg1".to_string(), "arg2".to_string()],
                env: None,
                timeout: None,
            }))
        }

        fn configuration(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
            Task::ready(Ok(None))
        }
    }
}