1pub mod extension;
2pub mod registry;
3
4use std::sync::Arc;
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::{ResultExt as _, rel_path::RelPath};
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 worktree_store::WorktreeStore,
19};
20
/// Initializes this module by delegating to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
24
// Declares the `Restart` action in the `context_server` group via gpui's `actions!` macro.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub enum ContextServerStatus {
35 Starting,
36 Running,
37 Stopped,
38 Error(Arc<str>),
39}
40
41impl ContextServerStatus {
42 fn from_state(state: &ContextServerState) -> Self {
43 match state {
44 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
45 ContextServerState::Running { .. } => ContextServerStatus::Running,
46 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
47 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
48 }
49 }
50}
51
52enum ContextServerState {
53 Starting {
54 server: Arc<ContextServer>,
55 configuration: Arc<ContextServerConfiguration>,
56 _task: Task<()>,
57 },
58 Running {
59 server: Arc<ContextServer>,
60 configuration: Arc<ContextServerConfiguration>,
61 },
62 Stopped {
63 server: Arc<ContextServer>,
64 configuration: Arc<ContextServerConfiguration>,
65 },
66 Error {
67 server: Arc<ContextServer>,
68 configuration: Arc<ContextServerConfiguration>,
69 error: Arc<str>,
70 },
71}
72
73impl ContextServerState {
74 pub fn server(&self) -> Arc<ContextServer> {
75 match self {
76 ContextServerState::Starting { server, .. } => server.clone(),
77 ContextServerState::Running { server, .. } => server.clone(),
78 ContextServerState::Stopped { server, .. } => server.clone(),
79 ContextServerState::Error { server, .. } => server.clone(),
80 }
81 }
82
83 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
84 match self {
85 ContextServerState::Starting { configuration, .. } => configuration.clone(),
86 ContextServerState::Running { configuration, .. } => configuration.clone(),
87 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
88 ContextServerState::Error { configuration, .. } => configuration.clone(),
89 }
90 }
91}
92
93#[derive(Debug, PartialEq, Eq)]
94pub enum ContextServerConfiguration {
95 Custom {
96 command: ContextServerCommand,
97 },
98 Extension {
99 command: ContextServerCommand,
100 settings: serde_json::Value,
101 },
102}
103
104impl ContextServerConfiguration {
105 pub fn command(&self) -> &ContextServerCommand {
106 match self {
107 ContextServerConfiguration::Custom { command } => command,
108 ContextServerConfiguration::Extension { command, .. } => command,
109 }
110 }
111
112 pub async fn from_settings(
113 settings: ContextServerSettings,
114 id: ContextServerId,
115 registry: Entity<ContextServerDescriptorRegistry>,
116 worktree_store: Entity<WorktreeStore>,
117 cx: &AsyncApp,
118 ) -> Option<Self> {
119 match settings {
120 ContextServerSettings::Custom {
121 enabled: _,
122 command,
123 } => Some(ContextServerConfiguration::Custom { command }),
124 ContextServerSettings::Extension {
125 enabled: _,
126 settings,
127 } => {
128 let descriptor = cx
129 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
130 .ok()
131 .flatten()?;
132
133 let command = descriptor.command(worktree_store, cx).await.log_err()?;
134
135 Some(ContextServerConfiguration::Extension { command, settings })
136 }
137 }
138 }
139}
140
/// Closure that builds a [`ContextServer`] from a server id and its resolved
/// configuration.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
143
/// Tracks the context servers of a project together with their lifecycle state.
pub struct ContextServerStore {
    /// Per-server settings, keyed by server id, as last resolved from `ProjectSettings`.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every server known to the store.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Used to resolve the settings location and a fallback root directory.
    worktree_store: Entity<WorktreeStore>,
    /// Weak handle to the project; queried for the active project directory.
    project: WeakEntity<Project>,
    /// Registry of context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The in-flight server maintenance pass, if one is running.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for constructing servers; when `None`, stdio servers are created.
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a change arrives while a maintenance pass is running, so that another pass follows.
    needs_server_update: bool,
    /// Registry and settings observers; empty when the maintenance loop is disabled.
    _subscriptions: Vec<Subscription>,
}
155
156pub enum Event {
157 ServerStatusChanged {
158 server_id: ContextServerId,
159 status: ContextServerStatus,
160 },
161}
162
163impl EventEmitter<Event> for ContextServerStore {}
164
165impl ContextServerStore {
166 pub fn new(
167 worktree_store: Entity<WorktreeStore>,
168 weak_project: WeakEntity<Project>,
169 cx: &mut Context<Self>,
170 ) -> Self {
171 Self::new_internal(
172 true,
173 None,
174 ContextServerDescriptorRegistry::default_global(cx),
175 worktree_store,
176 weak_project,
177 cx,
178 )
179 }
180
181 /// Returns all configured context server ids, regardless of enabled state.
182 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
183 self.context_server_settings
184 .keys()
185 .cloned()
186 .map(ContextServerId)
187 .collect()
188 }
189
190 #[cfg(any(test, feature = "test-support"))]
191 pub fn test(
192 registry: Entity<ContextServerDescriptorRegistry>,
193 worktree_store: Entity<WorktreeStore>,
194 weak_project: WeakEntity<Project>,
195 cx: &mut Context<Self>,
196 ) -> Self {
197 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
198 }
199
200 #[cfg(any(test, feature = "test-support"))]
201 pub fn test_maintain_server_loop(
202 context_server_factory: ContextServerFactory,
203 registry: Entity<ContextServerDescriptorRegistry>,
204 worktree_store: Entity<WorktreeStore>,
205 weak_project: WeakEntity<Project>,
206 cx: &mut Context<Self>,
207 ) -> Self {
208 Self::new_internal(
209 true,
210 Some(context_server_factory),
211 registry,
212 worktree_store,
213 weak_project,
214 cx,
215 )
216 }
217
218 fn new_internal(
219 maintain_server_loop: bool,
220 context_server_factory: Option<ContextServerFactory>,
221 registry: Entity<ContextServerDescriptorRegistry>,
222 worktree_store: Entity<WorktreeStore>,
223 weak_project: WeakEntity<Project>,
224 cx: &mut Context<Self>,
225 ) -> Self {
226 let subscriptions = if maintain_server_loop {
227 vec![
228 cx.observe(®istry, |this, _registry, cx| {
229 this.available_context_servers_changed(cx);
230 }),
231 cx.observe_global::<SettingsStore>(|this, cx| {
232 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
233 if &this.context_server_settings == settings {
234 return;
235 }
236 this.context_server_settings = settings.clone();
237 this.available_context_servers_changed(cx);
238 }),
239 ]
240 } else {
241 Vec::new()
242 };
243
244 let mut this = Self {
245 _subscriptions: subscriptions,
246 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
247 .clone(),
248 worktree_store,
249 project: weak_project,
250 registry,
251 needs_server_update: false,
252 servers: HashMap::default(),
253 update_servers_task: None,
254 context_server_factory,
255 };
256 if maintain_server_loop {
257 this.available_context_servers_changed(cx);
258 }
259 this
260 }
261
262 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
263 self.servers.get(id).map(|state| state.server())
264 }
265
266 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
267 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
268 Some(server.clone())
269 } else {
270 None
271 }
272 }
273
274 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
275 self.servers.get(id).map(ContextServerStatus::from_state)
276 }
277
278 pub fn configuration_for_server(
279 &self,
280 id: &ContextServerId,
281 ) -> Option<Arc<ContextServerConfiguration>> {
282 self.servers.get(id).map(|state| state.configuration())
283 }
284
285 pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
286 self.servers
287 .keys()
288 .cloned()
289 .chain(
290 self.registry
291 .read(cx)
292 .context_server_descriptors()
293 .into_iter()
294 .map(|(id, _)| ContextServerId(id)),
295 )
296 .collect()
297 }
298
299 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
300 self.servers
301 .values()
302 .filter_map(|state| {
303 if let ContextServerState::Running { server, .. } = state {
304 Some(server.clone())
305 } else {
306 None
307 }
308 })
309 .collect()
310 }
311
312 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
313 cx.spawn(async move |this, cx| {
314 let this = this.upgrade().context("Context server store dropped")?;
315 let settings = this
316 .update(cx, |this, _| {
317 this.context_server_settings.get(&server.id().0).cloned()
318 })
319 .ok()
320 .flatten()
321 .context("Failed to get context server settings")?;
322
323 if !settings.enabled() {
324 return Ok(());
325 }
326
327 let (registry, worktree_store) = this.update(cx, |this, _| {
328 (this.registry.clone(), this.worktree_store.clone())
329 })?;
330 let configuration = ContextServerConfiguration::from_settings(
331 settings,
332 server.id(),
333 registry,
334 worktree_store,
335 cx,
336 )
337 .await
338 .context("Failed to create context server configuration")?;
339
340 this.update(cx, |this, cx| {
341 this.run_server(server, Arc::new(configuration), cx)
342 })
343 })
344 .detach_and_log_err(cx);
345 }
346
347 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
348 if matches!(
349 self.servers.get(id),
350 Some(ContextServerState::Stopped { .. })
351 ) {
352 return Ok(());
353 }
354
355 let state = self
356 .servers
357 .remove(id)
358 .context("Context server not found")?;
359
360 let server = state.server();
361 let configuration = state.configuration();
362 let mut result = Ok(());
363 if let ContextServerState::Running { server, .. } = &state {
364 result = server.stop();
365 }
366 drop(state);
367
368 self.update_server_state(
369 id.clone(),
370 ContextServerState::Stopped {
371 configuration,
372 server,
373 },
374 cx,
375 );
376
377 result
378 }
379
380 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
381 if let Some(state) = self.servers.get(id) {
382 let configuration = state.configuration();
383
384 self.stop_server(&state.server().id(), cx)?;
385 let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
386 self.run_server(new_server, configuration, cx);
387 }
388 Ok(())
389 }
390
391 fn run_server(
392 &mut self,
393 server: Arc<ContextServer>,
394 configuration: Arc<ContextServerConfiguration>,
395 cx: &mut Context<Self>,
396 ) {
397 let id = server.id();
398 if matches!(
399 self.servers.get(&id),
400 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
401 ) {
402 self.stop_server(&id, cx).log_err();
403 }
404
405 let task = cx.spawn({
406 let id = server.id();
407 let server = server.clone();
408 let configuration = configuration.clone();
409 async move |this, cx| {
410 match server.clone().start(cx).await {
411 Ok(_) => {
412 debug_assert!(server.client().is_some());
413
414 this.update(cx, |this, cx| {
415 this.update_server_state(
416 id.clone(),
417 ContextServerState::Running {
418 server,
419 configuration,
420 },
421 cx,
422 )
423 })
424 .log_err()
425 }
426 Err(err) => {
427 log::error!("{} context server failed to start: {}", id, err);
428 this.update(cx, |this, cx| {
429 this.update_server_state(
430 id.clone(),
431 ContextServerState::Error {
432 configuration,
433 server,
434 error: err.to_string().into(),
435 },
436 cx,
437 )
438 })
439 .log_err()
440 }
441 };
442 }
443 });
444
445 self.update_server_state(
446 id.clone(),
447 ContextServerState::Starting {
448 configuration,
449 _task: task,
450 server,
451 },
452 cx,
453 );
454 }
455
456 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
457 let state = self
458 .servers
459 .remove(id)
460 .context("Context server not found")?;
461 drop(state);
462 cx.emit(Event::ServerStatusChanged {
463 server_id: id.clone(),
464 status: ContextServerStatus::Stopped,
465 });
466 Ok(())
467 }
468
469 fn create_context_server(
470 &self,
471 id: ContextServerId,
472 configuration: Arc<ContextServerConfiguration>,
473 cx: &mut Context<Self>,
474 ) -> Arc<ContextServer> {
475 let root_path = self
476 .project
477 .read_with(cx, |project, cx| project.active_project_directory(cx))
478 .ok()
479 .flatten()
480 .or_else(|| {
481 self.worktree_store.read_with(cx, |store, cx| {
482 store.visible_worktrees(cx).fold(None, |acc, item| {
483 if acc.is_none() {
484 item.read(cx).root_dir()
485 } else {
486 acc
487 }
488 })
489 })
490 });
491
492 if let Some(factory) = self.context_server_factory.as_ref() {
493 factory(id, configuration)
494 } else {
495 Arc::new(ContextServer::stdio(
496 id,
497 configuration.command().clone(),
498 root_path,
499 ))
500 }
501 }
502
503 fn resolve_context_server_settings<'a>(
504 worktree_store: &'a Entity<WorktreeStore>,
505 cx: &'a App,
506 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
507 let location = worktree_store
508 .read(cx)
509 .visible_worktrees(cx)
510 .next()
511 .map(|worktree| settings::SettingsLocation {
512 worktree_id: worktree.read(cx).id(),
513 path: RelPath::empty(),
514 });
515 &ProjectSettings::get(location, cx).context_servers
516 }
517
518 fn update_server_state(
519 &mut self,
520 id: ContextServerId,
521 state: ContextServerState,
522 cx: &mut Context<Self>,
523 ) {
524 let status = ContextServerStatus::from_state(&state);
525 self.servers.insert(id.clone(), state);
526 cx.emit(Event::ServerStatusChanged {
527 server_id: id,
528 status,
529 });
530 }
531
532 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
533 if self.update_servers_task.is_some() {
534 self.needs_server_update = true;
535 } else {
536 self.needs_server_update = false;
537 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
538 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
539 log::error!("Error maintaining context servers: {}", err);
540 }
541
542 this.update(cx, |this, cx| {
543 this.update_servers_task.take();
544 if this.needs_server_update {
545 this.available_context_servers_changed(cx);
546 }
547 })?;
548
549 Ok(())
550 }));
551 }
552 }
553
    /// Reconciles the servers held by the store with the configured ones.
    ///
    /// Steps:
    /// 1. Snapshot the settings and add a default extension entry for every
    ///    registered descriptor that has no settings yet.
    /// 2. Split the entries into enabled and disabled, and resolve a
    ///    configuration for each enabled one (entries that fail to resolve are
    ///    dropped).
    /// 3. Decide per server: existing servers without a resolved configuration
    ///    are stopped when they are disabled and removed otherwise; resolved
    ///    servers are (re)started when their configuration differs from the
    ///    stored one, when they are stopped, or when they are new.
    /// 4. Apply the decisions in the order stop, remove, start.
    ///
    /// # Errors
    ///
    /// Fails if the store has been dropped, or if stopping or removing a
    /// server fails; the remaining actions are then skipped.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
            (
                this.context_server_settings.clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Registered descriptors without explicit settings get the default extension settings.
        for (id, _) in
            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
        {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::default_extension());
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Resolve all enabled servers concurrently; entries that resolve to `None` are dropped.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(|config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        this.update(cx, |this, cx| {
            for server_id in this.servers.keys() {
                // Servers that are not in `configured_servers` are either stopped (when they
                // are merely disabled) or removed from the store. Removal happens, for
                // example, if the user deleted a server from the context server settings.
                if !configured_servers.contains_key(server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                // (Re)start when the server is new, stopped, or its configuration changed.
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    let server = this.create_context_server(id.clone(), config.clone(), cx);
                    servers_to_start.push((server, config));
                    if this.servers.contains_key(&id) {
                        servers_to_stop.insert(id);
                    }
                }
            }
        })?;

        this.update(cx, |this, cx| {
            for id in servers_to_stop {
                this.stop_server(&id, cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, cx)?;
            }
            for (server, config) in servers_to_start {
                this.run_server(server, config, cx);
            }
            anyhow::Ok(())
        })?
    }
637}
638
639#[cfg(test)]
640mod tests {
641 use super::*;
642 use crate::{
643 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
644 project_settings::ProjectSettings,
645 };
646 use context_server::test::create_fake_transport;
647 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
648 use serde_json::json;
649 use std::{cell::RefCell, path::PathBuf, rc::Rc};
650 use util::path;
651
652 #[gpui::test]
653 async fn test_context_server_status(cx: &mut TestAppContext) {
654 const SERVER_1_ID: &str = "mcp-1";
655 const SERVER_2_ID: &str = "mcp-2";
656
657 let (_fs, project) = setup_context_server_test(
658 cx,
659 json!({"code.rs": ""}),
660 vec![
661 (SERVER_1_ID.into(), dummy_server_settings()),
662 (SERVER_2_ID.into(), dummy_server_settings()),
663 ],
664 )
665 .await;
666
667 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
668 let store = cx.new(|cx| {
669 ContextServerStore::test(
670 registry.clone(),
671 project.read(cx).worktree_store(),
672 project.downgrade(),
673 cx,
674 )
675 });
676
677 let server_1_id = ContextServerId(SERVER_1_ID.into());
678 let server_2_id = ContextServerId(SERVER_2_ID.into());
679
680 let server_1 = Arc::new(ContextServer::new(
681 server_1_id.clone(),
682 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
683 ));
684 let server_2 = Arc::new(ContextServer::new(
685 server_2_id.clone(),
686 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
687 ));
688
689 store.update(cx, |store, cx| store.start_server(server_1, cx));
690
691 cx.run_until_parked();
692
693 cx.update(|cx| {
694 assert_eq!(
695 store.read(cx).status_for_server(&server_1_id),
696 Some(ContextServerStatus::Running)
697 );
698 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
699 });
700
701 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
702
703 cx.run_until_parked();
704
705 cx.update(|cx| {
706 assert_eq!(
707 store.read(cx).status_for_server(&server_1_id),
708 Some(ContextServerStatus::Running)
709 );
710 assert_eq!(
711 store.read(cx).status_for_server(&server_2_id),
712 Some(ContextServerStatus::Running)
713 );
714 });
715
716 store
717 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
718 .unwrap();
719
720 cx.update(|cx| {
721 assert_eq!(
722 store.read(cx).status_for_server(&server_1_id),
723 Some(ContextServerStatus::Running)
724 );
725 assert_eq!(
726 store.read(cx).status_for_server(&server_2_id),
727 Some(ContextServerStatus::Stopped)
728 );
729 });
730 }
731
732 #[gpui::test]
733 async fn test_context_server_status_events(cx: &mut TestAppContext) {
734 const SERVER_1_ID: &str = "mcp-1";
735 const SERVER_2_ID: &str = "mcp-2";
736
737 let (_fs, project) = setup_context_server_test(
738 cx,
739 json!({"code.rs": ""}),
740 vec![
741 (SERVER_1_ID.into(), dummy_server_settings()),
742 (SERVER_2_ID.into(), dummy_server_settings()),
743 ],
744 )
745 .await;
746
747 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
748 let store = cx.new(|cx| {
749 ContextServerStore::test(
750 registry.clone(),
751 project.read(cx).worktree_store(),
752 project.downgrade(),
753 cx,
754 )
755 });
756
757 let server_1_id = ContextServerId(SERVER_1_ID.into());
758 let server_2_id = ContextServerId(SERVER_2_ID.into());
759
760 let server_1 = Arc::new(ContextServer::new(
761 server_1_id.clone(),
762 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
763 ));
764 let server_2 = Arc::new(ContextServer::new(
765 server_2_id.clone(),
766 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
767 ));
768
769 let _server_events = assert_server_events(
770 &store,
771 vec![
772 (server_1_id.clone(), ContextServerStatus::Starting),
773 (server_1_id, ContextServerStatus::Running),
774 (server_2_id.clone(), ContextServerStatus::Starting),
775 (server_2_id.clone(), ContextServerStatus::Running),
776 (server_2_id.clone(), ContextServerStatus::Stopped),
777 ],
778 cx,
779 );
780
781 store.update(cx, |store, cx| store.start_server(server_1, cx));
782
783 cx.run_until_parked();
784
785 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
786
787 cx.run_until_parked();
788
789 store
790 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
791 .unwrap();
792 }
793
794 #[gpui::test(iterations = 25)]
795 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
796 const SERVER_1_ID: &str = "mcp-1";
797
798 let (_fs, project) = setup_context_server_test(
799 cx,
800 json!({"code.rs": ""}),
801 vec![(SERVER_1_ID.into(), dummy_server_settings())],
802 )
803 .await;
804
805 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
806 let store = cx.new(|cx| {
807 ContextServerStore::test(
808 registry.clone(),
809 project.read(cx).worktree_store(),
810 project.downgrade(),
811 cx,
812 )
813 });
814
815 let server_id = ContextServerId(SERVER_1_ID.into());
816
817 let server_with_same_id_1 = Arc::new(ContextServer::new(
818 server_id.clone(),
819 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
820 ));
821 let server_with_same_id_2 = Arc::new(ContextServer::new(
822 server_id.clone(),
823 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
824 ));
825
826 // If we start another server with the same id, we should report that we stopped the previous one
827 let _server_events = assert_server_events(
828 &store,
829 vec![
830 (server_id.clone(), ContextServerStatus::Starting),
831 (server_id.clone(), ContextServerStatus::Stopped),
832 (server_id.clone(), ContextServerStatus::Starting),
833 (server_id.clone(), ContextServerStatus::Running),
834 ],
835 cx,
836 );
837
838 store.update(cx, |store, cx| {
839 store.start_server(server_with_same_id_1.clone(), cx)
840 });
841 store.update(cx, |store, cx| {
842 store.start_server(server_with_same_id_2.clone(), cx)
843 });
844
845 cx.run_until_parked();
846
847 cx.update(|cx| {
848 assert_eq!(
849 store.read(cx).status_for_server(&server_id),
850 Some(ContextServerStatus::Running)
851 );
852 });
853 }
854
855 #[gpui::test]
856 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
857 const SERVER_1_ID: &str = "mcp-1";
858 const SERVER_2_ID: &str = "mcp-2";
859
860 let server_1_id = ContextServerId(SERVER_1_ID.into());
861 let server_2_id = ContextServerId(SERVER_2_ID.into());
862
863 let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
864
865 let (_fs, project) = setup_context_server_test(
866 cx,
867 json!({"code.rs": ""}),
868 vec![(
869 SERVER_1_ID.into(),
870 ContextServerSettings::Extension {
871 enabled: true,
872 settings: json!({
873 "somevalue": true
874 }),
875 },
876 )],
877 )
878 .await;
879
880 let executor = cx.executor();
881 let registry = cx.new(|cx| {
882 let mut registry = ContextServerDescriptorRegistry::new();
883 registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
884 registry
885 });
886 let store = cx.new(|cx| {
887 ContextServerStore::test_maintain_server_loop(
888 Box::new(move |id, _| {
889 Arc::new(ContextServer::new(
890 id.clone(),
891 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
892 ))
893 }),
894 registry.clone(),
895 project.read(cx).worktree_store(),
896 project.downgrade(),
897 cx,
898 )
899 });
900
901 // Ensure that mcp-1 starts up
902 {
903 let _server_events = assert_server_events(
904 &store,
905 vec![
906 (server_1_id.clone(), ContextServerStatus::Starting),
907 (server_1_id.clone(), ContextServerStatus::Running),
908 ],
909 cx,
910 );
911 cx.run_until_parked();
912 }
913
914 // Ensure that mcp-1 is restarted when the configuration was changed
915 {
916 let _server_events = assert_server_events(
917 &store,
918 vec![
919 (server_1_id.clone(), ContextServerStatus::Stopped),
920 (server_1_id.clone(), ContextServerStatus::Starting),
921 (server_1_id.clone(), ContextServerStatus::Running),
922 ],
923 cx,
924 );
925 set_context_server_configuration(
926 vec![(
927 server_1_id.0.clone(),
928 settings::ContextServerSettingsContent::Extension {
929 enabled: true,
930 settings: json!({
931 "somevalue": false
932 }),
933 },
934 )],
935 cx,
936 );
937
938 cx.run_until_parked();
939 }
940
941 // Ensure that mcp-1 is not restarted when the configuration was not changed
942 {
943 let _server_events = assert_server_events(&store, vec![], cx);
944 set_context_server_configuration(
945 vec![(
946 server_1_id.0.clone(),
947 settings::ContextServerSettingsContent::Extension {
948 enabled: true,
949 settings: json!({
950 "somevalue": false
951 }),
952 },
953 )],
954 cx,
955 );
956
957 cx.run_until_parked();
958 }
959
960 // Ensure that mcp-2 is started once it is added to the settings
961 {
962 let _server_events = assert_server_events(
963 &store,
964 vec![
965 (server_2_id.clone(), ContextServerStatus::Starting),
966 (server_2_id.clone(), ContextServerStatus::Running),
967 ],
968 cx,
969 );
970 set_context_server_configuration(
971 vec![
972 (
973 server_1_id.0.clone(),
974 settings::ContextServerSettingsContent::Extension {
975 enabled: true,
976 settings: json!({
977 "somevalue": false
978 }),
979 },
980 ),
981 (
982 server_2_id.0.clone(),
983 settings::ContextServerSettingsContent::Custom {
984 enabled: true,
985 command: ContextServerCommand {
986 path: "somebinary".into(),
987 args: vec!["arg".to_string()],
988 env: None,
989 timeout: None,
990 },
991 },
992 ),
993 ],
994 cx,
995 );
996
997 cx.run_until_parked();
998 }
999
1000 // Ensure that mcp-2 is restarted once the args have changed
1001 {
1002 let _server_events = assert_server_events(
1003 &store,
1004 vec![
1005 (server_2_id.clone(), ContextServerStatus::Stopped),
1006 (server_2_id.clone(), ContextServerStatus::Starting),
1007 (server_2_id.clone(), ContextServerStatus::Running),
1008 ],
1009 cx,
1010 );
1011 set_context_server_configuration(
1012 vec![
1013 (
1014 server_1_id.0.clone(),
1015 settings::ContextServerSettingsContent::Extension {
1016 enabled: true,
1017 settings: json!({
1018 "somevalue": false
1019 }),
1020 },
1021 ),
1022 (
1023 server_2_id.0.clone(),
1024 settings::ContextServerSettingsContent::Custom {
1025 enabled: true,
1026 command: ContextServerCommand {
1027 path: "somebinary".into(),
1028 args: vec!["anotherArg".to_string()],
1029 env: None,
1030 timeout: None,
1031 },
1032 },
1033 ),
1034 ],
1035 cx,
1036 );
1037
1038 cx.run_until_parked();
1039 }
1040
1041 // Ensure that mcp-2 is removed once it is removed from the settings
1042 {
1043 let _server_events = assert_server_events(
1044 &store,
1045 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1046 cx,
1047 );
1048 set_context_server_configuration(
1049 vec![(
1050 server_1_id.0.clone(),
1051 settings::ContextServerSettingsContent::Extension {
1052 enabled: true,
1053 settings: json!({
1054 "somevalue": false
1055 }),
1056 },
1057 )],
1058 cx,
1059 );
1060
1061 cx.run_until_parked();
1062
1063 cx.update(|cx| {
1064 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1065 });
1066 }
1067
1068 // Ensure that nothing happens if the settings do not change
1069 {
1070 let _server_events = assert_server_events(&store, vec![], cx);
1071 set_context_server_configuration(
1072 vec![(
1073 server_1_id.0.clone(),
1074 settings::ContextServerSettingsContent::Extension {
1075 enabled: true,
1076 settings: json!({
1077 "somevalue": false
1078 }),
1079 },
1080 )],
1081 cx,
1082 );
1083
1084 cx.run_until_parked();
1085
1086 cx.update(|cx| {
1087 assert_eq!(
1088 store.read(cx).status_for_server(&server_1_id),
1089 Some(ContextServerStatus::Running)
1090 );
1091 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1092 });
1093 }
1094 }
1095
1096 #[gpui::test]
1097 async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1098 const SERVER_1_ID: &str = "mcp-1";
1099
1100 let server_1_id = ContextServerId(SERVER_1_ID.into());
1101
1102 let (_fs, project) = setup_context_server_test(
1103 cx,
1104 json!({"code.rs": ""}),
1105 vec![(
1106 SERVER_1_ID.into(),
1107 ContextServerSettings::Custom {
1108 enabled: true,
1109 command: ContextServerCommand {
1110 path: "somebinary".into(),
1111 args: vec!["arg".to_string()],
1112 env: None,
1113 timeout: None,
1114 },
1115 },
1116 )],
1117 )
1118 .await;
1119
1120 let executor = cx.executor();
1121 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1122 let store = cx.new(|cx| {
1123 ContextServerStore::test_maintain_server_loop(
1124 Box::new(move |id, _| {
1125 Arc::new(ContextServer::new(
1126 id.clone(),
1127 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1128 ))
1129 }),
1130 registry.clone(),
1131 project.read(cx).worktree_store(),
1132 project.downgrade(),
1133 cx,
1134 )
1135 });
1136
1137 // Ensure that mcp-1 starts up
1138 {
1139 let _server_events = assert_server_events(
1140 &store,
1141 vec![
1142 (server_1_id.clone(), ContextServerStatus::Starting),
1143 (server_1_id.clone(), ContextServerStatus::Running),
1144 ],
1145 cx,
1146 );
1147 cx.run_until_parked();
1148 }
1149
1150 // Ensure that mcp-1 is stopped once it is disabled.
1151 {
1152 let _server_events = assert_server_events(
1153 &store,
1154 vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1155 cx,
1156 );
1157 set_context_server_configuration(
1158 vec![(
1159 server_1_id.0.clone(),
1160 settings::ContextServerSettingsContent::Custom {
1161 enabled: false,
1162 command: ContextServerCommand {
1163 path: "somebinary".into(),
1164 args: vec!["arg".to_string()],
1165 env: None,
1166 timeout: None,
1167 },
1168 },
1169 )],
1170 cx,
1171 );
1172
1173 cx.run_until_parked();
1174 }
1175
1176 // Ensure that mcp-1 is started once it is enabled again.
1177 {
1178 let _server_events = assert_server_events(
1179 &store,
1180 vec![
1181 (server_1_id.clone(), ContextServerStatus::Starting),
1182 (server_1_id.clone(), ContextServerStatus::Running),
1183 ],
1184 cx,
1185 );
1186 set_context_server_configuration(
1187 vec![(
1188 server_1_id.0.clone(),
1189 settings::ContextServerSettingsContent::Custom {
1190 enabled: true,
1191 command: ContextServerCommand {
1192 path: "somebinary".into(),
1193 args: vec!["arg".to_string()],
1194 timeout: None,
1195 env: None,
1196 },
1197 },
1198 )],
1199 cx,
1200 );
1201
1202 cx.run_until_parked();
1203 }
1204 }
1205
1206 fn set_context_server_configuration(
1207 context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1208 cx: &mut TestAppContext,
1209 ) {
1210 cx.update(|cx| {
1211 SettingsStore::update_global(cx, |store, cx| {
1212 store.update_user_settings(cx, |content| {
1213 content.project.context_servers.clear();
1214 for (id, config) in context_servers {
1215 content.project.context_servers.insert(id, config);
1216 }
1217 });
1218 })
1219 });
1220 }
1221
1222 struct ServerEvents {
1223 received_event_count: Rc<RefCell<usize>>,
1224 expected_event_count: usize,
1225 _subscription: Subscription,
1226 }
1227
1228 impl Drop for ServerEvents {
1229 fn drop(&mut self) {
1230 let actual_event_count = *self.received_event_count.borrow();
1231 assert_eq!(
1232 actual_event_count, self.expected_event_count,
1233 "
1234 Expected to receive {} context server store events, but received {} events",
1235 self.expected_event_count, actual_event_count
1236 );
1237 }
1238 }
1239
1240 fn dummy_server_settings() -> ContextServerSettings {
1241 ContextServerSettings::Custom {
1242 enabled: true,
1243 command: ContextServerCommand {
1244 path: "somebinary".into(),
1245 args: vec!["arg".to_string()],
1246 env: None,
1247 timeout: None,
1248 },
1249 }
1250 }
1251
1252 fn assert_server_events(
1253 store: &Entity<ContextServerStore>,
1254 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1255 cx: &mut TestAppContext,
1256 ) -> ServerEvents {
1257 cx.update(|cx| {
1258 let mut ix = 0;
1259 let received_event_count = Rc::new(RefCell::new(0));
1260 let expected_event_count = expected_events.len();
1261 let subscription = cx.subscribe(store, {
1262 let received_event_count = received_event_count.clone();
1263 move |_, event, _| match event {
1264 Event::ServerStatusChanged {
1265 server_id: actual_server_id,
1266 status: actual_status,
1267 } => {
1268 let (expected_server_id, expected_status) = &expected_events[ix];
1269
1270 assert_eq!(
1271 actual_server_id, expected_server_id,
1272 "Expected different server id at index {}",
1273 ix
1274 );
1275 assert_eq!(
1276 actual_status, expected_status,
1277 "Expected different status at index {}",
1278 ix
1279 );
1280 ix += 1;
1281 *received_event_count.borrow_mut() += 1;
1282 }
1283 }
1284 });
1285 ServerEvents {
1286 expected_event_count,
1287 received_event_count,
1288 _subscription: subscription,
1289 }
1290 })
1291 }
1292
1293 async fn setup_context_server_test(
1294 cx: &mut TestAppContext,
1295 files: serde_json::Value,
1296 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1297 ) -> (Arc<FakeFs>, Entity<Project>) {
1298 cx.update(|cx| {
1299 let settings_store = SettingsStore::test(cx);
1300 cx.set_global(settings_store);
1301 Project::init_settings(cx);
1302 let mut settings = ProjectSettings::get_global(cx).clone();
1303 for (id, config) in context_server_configurations {
1304 settings.context_servers.insert(id, config);
1305 }
1306 ProjectSettings::override_global(settings, cx);
1307 });
1308
1309 let fs = FakeFs::new(cx.executor());
1310 fs.insert_tree(path!("/test"), files).await;
1311 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1312
1313 (fs, project)
1314 }
1315
1316 struct FakeContextServerDescriptor {
1317 path: PathBuf,
1318 }
1319
1320 impl FakeContextServerDescriptor {
1321 fn new(path: impl Into<PathBuf>) -> Self {
1322 Self { path: path.into() }
1323 }
1324 }
1325
1326 impl ContextServerDescriptor for FakeContextServerDescriptor {
1327 fn command(
1328 &self,
1329 _worktree_store: Entity<WorktreeStore>,
1330 _cx: &AsyncApp,
1331 ) -> Task<Result<ContextServerCommand>> {
1332 Task::ready(Ok(ContextServerCommand {
1333 path: self.path.clone(),
1334 args: vec!["arg1".to_string(), "arg2".to_string()],
1335 env: None,
1336 timeout: None,
1337 }))
1338 }
1339
1340 fn configuration(
1341 &self,
1342 _worktree_store: Entity<WorktreeStore>,
1343 _cx: &AsyncApp,
1344 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1345 Task::ready(Ok(None))
1346 }
1347 }
1348}