1pub mod extension;
2pub mod registry;
3
4use std::sync::Arc;
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::{ResultExt as _, rel_path::RelPath};
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 worktree_store::WorktreeStore,
19};
20
21pub fn init(cx: &mut App) {
22 extension::init(cx);
23}
24
// Declares the `Restart` action in the `context_server` action namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
32
/// Externally visible lifecycle status of a context server.
///
/// This is what `ContextServerStore::status_for_server` returns and what is carried by
/// `Event::ServerStatusChanged`. It mirrors the internal `ContextServerState` without
/// exposing the server handle or configuration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// A start was requested and has not completed yet.
    Starting,
    /// The server's `start` completed successfully.
    Running,
    /// The server was stopped; also reported when a server is removed from the store.
    Stopped,
    /// The last start attempt failed; holds the error message.
    Error(Arc<str>),
}
40
41impl ContextServerStatus {
42 fn from_state(state: &ContextServerState) -> Self {
43 match state {
44 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
45 ContextServerState::Running { .. } => ContextServerStatus::Running,
46 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
47 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
48 }
49 }
50}
51
/// Internal per-server state tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration the server was created
/// from, so a server can be restarted from any state.
enum ContextServerState {
    /// A start is in progress. `_task` drives the start and records the outcome.
    /// NOTE(review): dropping this state drops `_task`, which presumably cancels the
    /// pending start (gpui task-drop semantics) — confirm.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped (explicitly, by being disabled, or before a restart).
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The last start attempt failed with `error`.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
72
73impl ContextServerState {
74 pub fn server(&self) -> Arc<ContextServer> {
75 match self {
76 ContextServerState::Starting { server, .. } => server.clone(),
77 ContextServerState::Running { server, .. } => server.clone(),
78 ContextServerState::Stopped { server, .. } => server.clone(),
79 ContextServerState::Error { server, .. } => server.clone(),
80 }
81 }
82
83 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
84 match self {
85 ContextServerState::Starting { configuration, .. } => configuration.clone(),
86 ContextServerState::Running { configuration, .. } => configuration.clone(),
87 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
88 ContextServerState::Error { configuration, .. } => configuration.clone(),
89 }
90 }
91}
92
/// Launch configuration for a context server, resolved from settings.
///
/// The store compares configurations with `PartialEq` to decide whether an existing
/// server must be restarted after settings change.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A user-defined server launched with an explicit command from settings.
    Custom {
        command: ContextServerCommand,
    },
    /// An extension-provided server. `command` comes from the extension's registered
    /// descriptor; `settings` is the `settings` value from the user's extension settings.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
103
104impl ContextServerConfiguration {
105 pub fn command(&self) -> &ContextServerCommand {
106 match self {
107 ContextServerConfiguration::Custom { command } => command,
108 ContextServerConfiguration::Extension { command, .. } => command,
109 }
110 }
111
112 pub async fn from_settings(
113 settings: ContextServerSettings,
114 id: ContextServerId,
115 registry: Entity<ContextServerDescriptorRegistry>,
116 worktree_store: Entity<WorktreeStore>,
117 cx: &AsyncApp,
118 ) -> Option<Self> {
119 match settings {
120 ContextServerSettings::Custom {
121 enabled: _,
122 command,
123 } => Some(ContextServerConfiguration::Custom { command }),
124 ContextServerSettings::Extension {
125 enabled: _,
126 settings,
127 } => {
128 let descriptor = cx
129 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
130 .ok()
131 .flatten()?;
132
133 let command = descriptor.command(worktree_store, cx).await.log_err()?;
134
135 Some(ContextServerConfiguration::Extension { command, settings })
136 }
137 }
138 }
139}
140
/// Builds a (not yet started) `ContextServer` for an id and its configuration.
///
/// When a store has a factory, it uses it instead of constructing a stdio server; only
/// the test-support constructor `ContextServerStore::test_maintain_server_loop` sets one.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
143
/// Owns the context servers of a project and reconciles them with settings and the
/// extension descriptor registry.
pub struct ContextServerStore {
    /// Context server settings resolved for the first visible worktree (or globally when
    /// there is none), keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    /// Weak handle to the project; upgraded on demand to pick a server root path.
    project: WeakEntity<Project>,
    /// Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The in-flight reconciliation pass, if any. Only one runs at a time.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for server construction (set by the test-support constructor).
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a reconciliation is requested while another is still running, so that
    /// a follow-up pass is scheduled when the current one finishes.
    needs_server_update: bool,
    _subscriptions: Vec<Subscription>,
}
155
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// A server's status changed. Removal of a server is reported as `Stopped`.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
164
165impl ContextServerStore {
166 pub fn new(
167 worktree_store: Entity<WorktreeStore>,
168 weak_project: WeakEntity<Project>,
169 cx: &mut Context<Self>,
170 ) -> Self {
171 Self::new_internal(
172 true,
173 None,
174 ContextServerDescriptorRegistry::default_global(cx),
175 worktree_store,
176 weak_project,
177 cx,
178 )
179 }
180
181 /// Returns all configured context server ids, regardless of enabled state.
182 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
183 self.context_server_settings
184 .keys()
185 .cloned()
186 .map(ContextServerId)
187 .collect()
188 }
189
190 #[cfg(any(test, feature = "test-support"))]
191 pub fn test(
192 registry: Entity<ContextServerDescriptorRegistry>,
193 worktree_store: Entity<WorktreeStore>,
194 weak_project: WeakEntity<Project>,
195 cx: &mut Context<Self>,
196 ) -> Self {
197 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
198 }
199
200 #[cfg(any(test, feature = "test-support"))]
201 pub fn test_maintain_server_loop(
202 context_server_factory: ContextServerFactory,
203 registry: Entity<ContextServerDescriptorRegistry>,
204 worktree_store: Entity<WorktreeStore>,
205 weak_project: WeakEntity<Project>,
206 cx: &mut Context<Self>,
207 ) -> Self {
208 Self::new_internal(
209 true,
210 Some(context_server_factory),
211 registry,
212 worktree_store,
213 weak_project,
214 cx,
215 )
216 }
217
218 fn new_internal(
219 maintain_server_loop: bool,
220 context_server_factory: Option<ContextServerFactory>,
221 registry: Entity<ContextServerDescriptorRegistry>,
222 worktree_store: Entity<WorktreeStore>,
223 weak_project: WeakEntity<Project>,
224 cx: &mut Context<Self>,
225 ) -> Self {
226 let subscriptions = if maintain_server_loop {
227 vec![
228 cx.observe(®istry, |this, _registry, cx| {
229 this.available_context_servers_changed(cx);
230 }),
231 cx.observe_global::<SettingsStore>(|this, cx| {
232 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
233 if &this.context_server_settings == settings {
234 return;
235 }
236 this.context_server_settings = settings.clone();
237 this.available_context_servers_changed(cx);
238 }),
239 ]
240 } else {
241 Vec::new()
242 };
243
244 let mut this = Self {
245 _subscriptions: subscriptions,
246 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
247 .clone(),
248 worktree_store,
249 project: weak_project,
250 registry,
251 needs_server_update: false,
252 servers: HashMap::default(),
253 update_servers_task: None,
254 context_server_factory,
255 };
256 if maintain_server_loop {
257 this.available_context_servers_changed(cx);
258 }
259 this
260 }
261
262 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
263 self.servers.get(id).map(|state| state.server())
264 }
265
266 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
267 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
268 Some(server.clone())
269 } else {
270 None
271 }
272 }
273
274 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
275 self.servers.get(id).map(ContextServerStatus::from_state)
276 }
277
278 pub fn configuration_for_server(
279 &self,
280 id: &ContextServerId,
281 ) -> Option<Arc<ContextServerConfiguration>> {
282 self.servers.get(id).map(|state| state.configuration())
283 }
284
285 pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
286 self.servers
287 .keys()
288 .cloned()
289 .chain(
290 self.registry
291 .read(cx)
292 .context_server_descriptors()
293 .into_iter()
294 .map(|(id, _)| ContextServerId(id)),
295 )
296 .collect()
297 }
298
299 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
300 self.servers
301 .values()
302 .filter_map(|state| {
303 if let ContextServerState::Running { server, .. } = state {
304 Some(server.clone())
305 } else {
306 None
307 }
308 })
309 .collect()
310 }
311
312 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
313 cx.spawn(async move |this, cx| {
314 let this = this.upgrade().context("Context server store dropped")?;
315 let settings = this
316 .update(cx, |this, _| {
317 this.context_server_settings.get(&server.id().0).cloned()
318 })
319 .ok()
320 .flatten()
321 .context("Failed to get context server settings")?;
322
323 if !settings.enabled() {
324 return Ok(());
325 }
326
327 let (registry, worktree_store) = this.update(cx, |this, _| {
328 (this.registry.clone(), this.worktree_store.clone())
329 })?;
330 let configuration = ContextServerConfiguration::from_settings(
331 settings,
332 server.id(),
333 registry,
334 worktree_store,
335 cx,
336 )
337 .await
338 .context("Failed to create context server configuration")?;
339
340 this.update(cx, |this, cx| {
341 this.run_server(server, Arc::new(configuration), cx)
342 })
343 })
344 .detach_and_log_err(cx);
345 }
346
347 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
348 if matches!(
349 self.servers.get(id),
350 Some(ContextServerState::Stopped { .. })
351 ) {
352 return Ok(());
353 }
354
355 let state = self
356 .servers
357 .remove(id)
358 .context("Context server not found")?;
359
360 let server = state.server();
361 let configuration = state.configuration();
362 let mut result = Ok(());
363 if let ContextServerState::Running { server, .. } = &state {
364 result = server.stop();
365 }
366 drop(state);
367
368 self.update_server_state(
369 id.clone(),
370 ContextServerState::Stopped {
371 configuration,
372 server,
373 },
374 cx,
375 );
376
377 result
378 }
379
380 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
381 if let Some(state) = self.servers.get(id) {
382 let configuration = state.configuration();
383
384 self.stop_server(&state.server().id(), cx)?;
385 let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
386 self.run_server(new_server, configuration, cx);
387 }
388 Ok(())
389 }
390
391 fn run_server(
392 &mut self,
393 server: Arc<ContextServer>,
394 configuration: Arc<ContextServerConfiguration>,
395 cx: &mut Context<Self>,
396 ) {
397 let id = server.id();
398 if matches!(
399 self.servers.get(&id),
400 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
401 ) {
402 self.stop_server(&id, cx).log_err();
403 }
404
405 let task = cx.spawn({
406 let id = server.id();
407 let server = server.clone();
408 let configuration = configuration.clone();
409 async move |this, cx| {
410 match server.clone().start(cx).await {
411 Ok(_) => {
412 debug_assert!(server.client().is_some());
413
414 this.update(cx, |this, cx| {
415 this.update_server_state(
416 id.clone(),
417 ContextServerState::Running {
418 server,
419 configuration,
420 },
421 cx,
422 )
423 })
424 .log_err()
425 }
426 Err(err) => {
427 log::error!("{} context server failed to start: {}", id, err);
428 this.update(cx, |this, cx| {
429 this.update_server_state(
430 id.clone(),
431 ContextServerState::Error {
432 configuration,
433 server,
434 error: err.to_string().into(),
435 },
436 cx,
437 )
438 })
439 .log_err()
440 }
441 };
442 }
443 });
444
445 self.update_server_state(
446 id.clone(),
447 ContextServerState::Starting {
448 configuration,
449 _task: task,
450 server,
451 },
452 cx,
453 );
454 }
455
456 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
457 let state = self
458 .servers
459 .remove(id)
460 .context("Context server not found")?;
461 drop(state);
462 cx.emit(Event::ServerStatusChanged {
463 server_id: id.clone(),
464 status: ContextServerStatus::Stopped,
465 });
466 Ok(())
467 }
468
469 fn create_context_server(
470 &self,
471 id: ContextServerId,
472 configuration: Arc<ContextServerConfiguration>,
473 cx: &mut Context<Self>,
474 ) -> Arc<ContextServer> {
475 let project = self.project.upgrade();
476 let mut root_path = None;
477 if let Some(project) = project {
478 let project = project.read(cx);
479 if project.is_local() {
480 if let Some(path) = project.active_project_directory(cx) {
481 root_path = Some(path);
482 } else {
483 for worktree in self.worktree_store.read(cx).visible_worktrees(cx) {
484 if let Some(path) = worktree.read(cx).root_dir() {
485 root_path = Some(path);
486 break;
487 }
488 }
489 }
490 }
491 };
492
493 if let Some(factory) = self.context_server_factory.as_ref() {
494 factory(id, configuration)
495 } else {
496 Arc::new(ContextServer::stdio(
497 id,
498 configuration.command().clone(),
499 root_path,
500 ))
501 }
502 }
503
504 fn resolve_context_server_settings<'a>(
505 worktree_store: &'a Entity<WorktreeStore>,
506 cx: &'a App,
507 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
508 let location = worktree_store
509 .read(cx)
510 .visible_worktrees(cx)
511 .next()
512 .map(|worktree| settings::SettingsLocation {
513 worktree_id: worktree.read(cx).id(),
514 path: RelPath::empty(),
515 });
516 &ProjectSettings::get(location, cx).context_servers
517 }
518
519 fn update_server_state(
520 &mut self,
521 id: ContextServerId,
522 state: ContextServerState,
523 cx: &mut Context<Self>,
524 ) {
525 let status = ContextServerStatus::from_state(&state);
526 self.servers.insert(id.clone(), state);
527 cx.emit(Event::ServerStatusChanged {
528 server_id: id,
529 status,
530 });
531 }
532
533 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
534 if self.update_servers_task.is_some() {
535 self.needs_server_update = true;
536 } else {
537 self.needs_server_update = false;
538 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
539 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
540 log::error!("Error maintaining context servers: {}", err);
541 }
542
543 this.update(cx, |this, cx| {
544 this.update_servers_task.take();
545 if this.needs_server_update {
546 this.available_context_servers_changed(cx);
547 }
548 })?;
549
550 Ok(())
551 }));
552 }
553 }
554
555 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
556 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
557 (
558 this.context_server_settings.clone(),
559 this.registry.clone(),
560 this.worktree_store.clone(),
561 )
562 })?;
563
564 for (id, _) in
565 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
566 {
567 configured_servers
568 .entry(id)
569 .or_insert(ContextServerSettings::default_extension());
570 }
571
572 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
573 configured_servers
574 .into_iter()
575 .partition(|(_, settings)| settings.enabled());
576
577 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
578 let id = ContextServerId(id);
579 ContextServerConfiguration::from_settings(
580 settings,
581 id.clone(),
582 registry.clone(),
583 worktree_store.clone(),
584 cx,
585 )
586 .map(|config| (id, config))
587 }))
588 .await
589 .into_iter()
590 .filter_map(|(id, config)| config.map(|config| (id, config)))
591 .collect::<HashMap<_, _>>();
592
593 let mut servers_to_start = Vec::new();
594 let mut servers_to_remove = HashSet::default();
595 let mut servers_to_stop = HashSet::default();
596
597 this.update(cx, |this, cx| {
598 for server_id in this.servers.keys() {
599 // All servers that are not in desired_servers should be removed from the store.
600 // This can happen if the user removed a server from the context server settings.
601 if !configured_servers.contains_key(server_id) {
602 if disabled_servers.contains_key(&server_id.0) {
603 servers_to_stop.insert(server_id.clone());
604 } else {
605 servers_to_remove.insert(server_id.clone());
606 }
607 }
608 }
609
610 for (id, config) in configured_servers {
611 let state = this.servers.get(&id);
612 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
613 let existing_config = state.as_ref().map(|state| state.configuration());
614 if existing_config.as_deref() != Some(&config) || is_stopped {
615 let config = Arc::new(config);
616 let server = this.create_context_server(id.clone(), config.clone(), cx);
617 servers_to_start.push((server, config));
618 if this.servers.contains_key(&id) {
619 servers_to_stop.insert(id);
620 }
621 }
622 }
623 })?;
624
625 this.update(cx, |this, cx| {
626 for id in servers_to_stop {
627 this.stop_server(&id, cx)?;
628 }
629 for id in servers_to_remove {
630 this.remove_server(&id, cx)?;
631 }
632 for (server, config) in servers_to_start {
633 this.run_server(server, config, cx);
634 }
635 anyhow::Ok(())
636 })?
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
        project_settings::ProjectSettings,
    };
    use context_server::test::create_fake_transport;
    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
    use serde_json::json;
    use std::{cell::RefCell, path::PathBuf, rc::Rc};
    use util::path;

    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }

    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }

    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }

    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Extension {
                    enabled: true,
                    settings: json!({
                        "somevalue": true
                    }),
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|cx| {
            let mut registry = ContextServerDescriptorRegistry::new();
            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
            registry
        });
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Stopped),
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        settings::ContextServerSettingsContent::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        settings::ContextServerSettingsContent::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".into(),
                                args: vec!["arg".to_string()],
                                env: None,
                                timeout: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is restarted once the args have changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Stopped),
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        settings::ContextServerSettingsContent::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        settings::ContextServerSettingsContent::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".into(),
                                args: vec!["anotherArg".to_string()],
                                env: None,
                                timeout: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }

        // Ensure that nothing happens if the settings do not change
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(
                    store.read(cx).status_for_server(&server_1_id),
                    Some(ContextServerStatus::Running)
                );
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }
    }

    #[gpui::test]
    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let server_1_id = ContextServerId(SERVER_1_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Custom {
                    enabled: true,
                    command: ContextServerCommand {
                        path: "somebinary".into(),
                        args: vec!["arg".to_string()],
                        env: None,
                        timeout: None,
                    },
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is stopped once it is disabled.
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Custom {
                        enabled: false,
                        command: ContextServerCommand {
                            path: "somebinary".into(),
                            args: vec!["arg".to_string()],
                            env: None,
                            timeout: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is started once it is enabled again.
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Custom {
                        enabled: true,
                        command: ContextServerCommand {
                            path: "somebinary".into(),
                            args: vec!["arg".to_string()],
                            timeout: None,
                            env: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }
    }

    /// Replaces the user-level `context_servers` settings with exactly `context_servers`.
    fn set_context_server_configuration(
        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |content| {
                    content.project.context_servers.clear();
                    for (id, config) in context_servers {
                        content.project.context_servers.insert(id, config);
                    }
                });
            })
        });
    }

    /// Guard returned by `assert_server_events`.
    ///
    /// On drop, it checks that exactly the expected number of events was received.
    struct ServerEvents {
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }

    impl Drop for ServerEvents {
        fn drop(&mut self) {
            // If the test is already failing (e.g. an event assertion fired), asserting
            // again here would panic during unwinding and abort the whole test process,
            // hiding the original failure message.
            if std::thread::panicking() {
                return;
            }
            let actual_event_count = *self.received_event_count.borrow();
            assert_eq!(
                actual_event_count, self.expected_event_count,
                "Expected to receive {} context server store events, but received {} events",
                self.expected_event_count, actual_event_count
            );
        }
    }

    /// Custom-server settings with a placeholder command, enabled.
    fn dummy_server_settings() -> ContextServerSettings {
        ContextServerSettings::Custom {
            enabled: true,
            command: ContextServerCommand {
                path: "somebinary".into(),
                args: vec!["arg".to_string()],
                env: None,
                timeout: None,
            },
        }
    }

    /// Subscribes to `store` and asserts that its status events match `expected_events`.
    ///
    /// Events are checked in order as they arrive. The returned guard checks the total
    /// count when dropped.
    fn assert_server_events(
        store: &Entity<ContextServerStore>,
        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
        cx: &mut TestAppContext,
    ) -> ServerEvents {
        cx.update(|cx| {
            let mut ix = 0;
            let received_event_count = Rc::new(RefCell::new(0));
            let expected_event_count = expected_events.len();
            let subscription = cx.subscribe(store, {
                let received_event_count = received_event_count.clone();
                move |_, event, _| match event {
                    Event::ServerStatusChanged {
                        server_id: actual_server_id,
                        status: actual_status,
                    } => {
                        let Some((expected_server_id, expected_status)) = expected_events.get(ix)
                        else {
                            panic!(
                                "Received unexpected event at index {}: {:?} for server {:?}",
                                ix, actual_status, actual_server_id
                            );
                        };

                        assert_eq!(
                            actual_server_id, expected_server_id,
                            "Expected different server id at index {}",
                            ix
                        );
                        assert_eq!(
                            actual_status, expected_status,
                            "Expected different status at index {}",
                            ix
                        );
                        ix += 1;
                        *received_event_count.borrow_mut() += 1;
                    }
                }
            });
            ServerEvents {
                expected_event_count,
                received_event_count,
                _subscription: subscription,
            }
        })
    }

    /// Installs test settings with the given context servers.
    ///
    /// Also creates a fake file system populated with `files` under `/test`, and opens a
    /// project on it.
    async fn setup_context_server_test(
        cx: &mut TestAppContext,
        files: serde_json::Value,
        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
    ) -> (Arc<FakeFs>, Entity<Project>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            let mut settings = ProjectSettings::get_global(cx).clone();
            for (id, config) in context_server_configurations {
                settings.context_servers.insert(id, config);
            }
            ProjectSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        (fs, project)
    }

    /// Descriptor that always yields a fixed command whose path is `path`.
    struct FakeContextServerDescriptor {
        path: PathBuf,
    }

    impl FakeContextServerDescriptor {
        fn new(path: impl Into<PathBuf>) -> Self {
            Self { path: path.into() }
        }
    }

    impl ContextServerDescriptor for FakeContextServerDescriptor {
        fn command(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<ContextServerCommand>> {
            Task::ready(Ok(ContextServerCommand {
                path: self.path.clone(),
                args: vec!["arg1".to_string(), "arg2".to_string()],
                env: None,
                timeout: None,
            }))
        }

        fn configuration(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
            Task::ready(Ok(None))
        }
    }
}