1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::ResultExt as _;
14
15use crate::{
16 project_settings::{ContextServerSettings, ProjectSettings},
17 worktree_store::WorktreeStore,
18};
19
/// Initializes this module's extension integration by delegating to
/// [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
23
// Declares the actions exposed under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
31
/// Payload-light view of a context server's lifecycle state, as reported by
/// `ContextServerStore::status_for_server` and `Event::ServerStatusChanged`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's start task has been spawned but has not finished yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped (also reported when a server is removed from the store).
    Stopped,
    /// The server failed to start; carries the error message.
    Error(Arc<str>),
}
39
40impl ContextServerStatus {
41 fn from_state(state: &ContextServerState) -> Self {
42 match state {
43 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
44 ContextServerState::Running { .. } => ContextServerStatus::Running,
45 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
46 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
47 }
48 }
49}
50
/// Internal per-server bookkeeping held by `ContextServerStore`.
///
/// Every variant carries the server handle and the configuration it was
/// created from, so both can be retrieved regardless of lifecycle stage.
enum ContextServerState {
    /// The server is being started.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Task driving the startup; it is owned by (and dropped together with) this state.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped but is still tracked by the store.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server failed to start; `error` holds the failure message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
71
72impl ContextServerState {
73 pub fn server(&self) -> Arc<ContextServer> {
74 match self {
75 ContextServerState::Starting { server, .. } => server.clone(),
76 ContextServerState::Running { server, .. } => server.clone(),
77 ContextServerState::Stopped { server, .. } => server.clone(),
78 ContextServerState::Error { server, .. } => server.clone(),
79 }
80 }
81
82 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
83 match self {
84 ContextServerState::Starting { configuration, .. } => configuration.clone(),
85 ContextServerState::Running { configuration, .. } => configuration.clone(),
86 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
87 ContextServerState::Error { configuration, .. } => configuration.clone(),
88 }
89 }
90}
91
/// Resolved configuration a context server is launched with.
///
/// Equality is used by the store to decide whether a server has to be
/// restarted after its settings changed.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server whose command comes directly from the user's settings.
    Custom {
        command: ContextServerCommand,
    },
    /// A server whose command is obtained from a registered descriptor,
    /// paired with the settings value configured for it.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
102
103impl ContextServerConfiguration {
104 pub fn command(&self) -> &ContextServerCommand {
105 match self {
106 ContextServerConfiguration::Custom { command } => command,
107 ContextServerConfiguration::Extension { command, .. } => command,
108 }
109 }
110
111 pub async fn from_settings(
112 settings: ContextServerSettings,
113 id: ContextServerId,
114 registry: Entity<ContextServerDescriptorRegistry>,
115 worktree_store: Entity<WorktreeStore>,
116 cx: &AsyncApp,
117 ) -> Option<Self> {
118 match settings {
119 ContextServerSettings::Custom {
120 enabled: _,
121 command,
122 } => Some(ContextServerConfiguration::Custom { command }),
123 ContextServerSettings::Extension {
124 enabled: _,
125 settings,
126 } => {
127 let descriptor = cx
128 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
129 .ok()
130 .flatten()?;
131
132 let command = descriptor.command(worktree_store, cx).await.log_err()?;
133
134 Some(ContextServerConfiguration::Extension { command, settings })
135 }
136 }
137 }
138}
139
/// Callback that builds a [`ContextServer`] for the given id and configuration.
/// When the store has one, it is used instead of creating a stdio server.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
142
/// Tracks the context servers of a project and keeps them in sync with the
/// project settings and the descriptor registry.
pub struct ContextServerStore {
    /// Context server settings, keyed by server id, as last read from the project settings.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Lifecycle state of every server currently tracked by the store.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The maintenance pass that is currently running, if any.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for server construction; `None` means stdio servers are created.
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a change arrives while a maintenance pass is running, so that
    /// another pass is started once the current one finishes.
    needs_server_update: bool,
    /// Keeps the registry and settings observers alive for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
153
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// Emitted whenever the state stored for a server changes, and when a
    /// server is removed from the store (reported as `Stopped`).
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
162
163impl ContextServerStore {
164 pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
165 Self::new_internal(
166 true,
167 None,
168 ContextServerDescriptorRegistry::default_global(cx),
169 worktree_store,
170 cx,
171 )
172 }
173
174 #[cfg(any(test, feature = "test-support"))]
175 pub fn test(
176 registry: Entity<ContextServerDescriptorRegistry>,
177 worktree_store: Entity<WorktreeStore>,
178 cx: &mut Context<Self>,
179 ) -> Self {
180 Self::new_internal(false, None, registry, worktree_store, cx)
181 }
182
183 #[cfg(any(test, feature = "test-support"))]
184 pub fn test_maintain_server_loop(
185 context_server_factory: ContextServerFactory,
186 registry: Entity<ContextServerDescriptorRegistry>,
187 worktree_store: Entity<WorktreeStore>,
188 cx: &mut Context<Self>,
189 ) -> Self {
190 Self::new_internal(
191 true,
192 Some(context_server_factory),
193 registry,
194 worktree_store,
195 cx,
196 )
197 }
198
199 fn new_internal(
200 maintain_server_loop: bool,
201 context_server_factory: Option<ContextServerFactory>,
202 registry: Entity<ContextServerDescriptorRegistry>,
203 worktree_store: Entity<WorktreeStore>,
204 cx: &mut Context<Self>,
205 ) -> Self {
206 let subscriptions = if maintain_server_loop {
207 vec![
208 cx.observe(®istry, |this, _registry, cx| {
209 this.available_context_servers_changed(cx);
210 }),
211 cx.observe_global::<SettingsStore>(|this, cx| {
212 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
213 if &this.context_server_settings == settings {
214 return;
215 }
216 this.context_server_settings = settings.clone();
217 this.available_context_servers_changed(cx);
218 }),
219 ]
220 } else {
221 Vec::new()
222 };
223
224 let mut this = Self {
225 _subscriptions: subscriptions,
226 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
227 .clone(),
228 worktree_store,
229 registry,
230 needs_server_update: false,
231 servers: HashMap::default(),
232 update_servers_task: None,
233 context_server_factory,
234 };
235 if maintain_server_loop {
236 this.available_context_servers_changed(cx);
237 }
238 this
239 }
240
241 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
242 self.servers.get(id).map(|state| state.server())
243 }
244
245 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
246 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
247 Some(server.clone())
248 } else {
249 None
250 }
251 }
252
253 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
254 self.servers.get(id).map(ContextServerStatus::from_state)
255 }
256
257 pub fn configuration_for_server(
258 &self,
259 id: &ContextServerId,
260 ) -> Option<Arc<ContextServerConfiguration>> {
261 self.servers.get(id).map(|state| state.configuration())
262 }
263
264 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
265 self.servers.keys().cloned().collect()
266 }
267
268 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
269 self.servers
270 .values()
271 .filter_map(|state| {
272 if let ContextServerState::Running { server, .. } = state {
273 Some(server.clone())
274 } else {
275 None
276 }
277 })
278 .collect()
279 }
280
281 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
282 cx.spawn(async move |this, cx| {
283 let this = this.upgrade().context("Context server store dropped")?;
284 let settings = this
285 .update(cx, |this, _| {
286 this.context_server_settings.get(&server.id().0).cloned()
287 })
288 .ok()
289 .flatten()
290 .context("Failed to get context server settings")?;
291
292 if !settings.enabled() {
293 return Ok(());
294 }
295
296 let (registry, worktree_store) = this.update(cx, |this, _| {
297 (this.registry.clone(), this.worktree_store.clone())
298 })?;
299 let configuration = ContextServerConfiguration::from_settings(
300 settings,
301 server.id(),
302 registry,
303 worktree_store,
304 cx,
305 )
306 .await
307 .context("Failed to create context server configuration")?;
308
309 this.update(cx, |this, cx| {
310 this.run_server(server, Arc::new(configuration), cx)
311 })
312 })
313 .detach_and_log_err(cx);
314 }
315
316 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
317 if matches!(
318 self.servers.get(id),
319 Some(ContextServerState::Stopped { .. })
320 ) {
321 return Ok(());
322 }
323
324 let state = self
325 .servers
326 .remove(id)
327 .context("Context server not found")?;
328
329 let server = state.server();
330 let configuration = state.configuration();
331 let mut result = Ok(());
332 if let ContextServerState::Running { server, .. } = &state {
333 result = server.stop();
334 }
335 drop(state);
336
337 self.update_server_state(
338 id.clone(),
339 ContextServerState::Stopped {
340 configuration,
341 server,
342 },
343 cx,
344 );
345
346 result
347 }
348
349 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
350 if let Some(state) = self.servers.get(&id) {
351 let configuration = state.configuration();
352
353 self.stop_server(&state.server().id(), cx)?;
354 let new_server = self.create_context_server(id.clone(), configuration.clone())?;
355 self.run_server(new_server, configuration, cx);
356 }
357 Ok(())
358 }
359
360 fn run_server(
361 &mut self,
362 server: Arc<ContextServer>,
363 configuration: Arc<ContextServerConfiguration>,
364 cx: &mut Context<Self>,
365 ) {
366 let id = server.id();
367 if matches!(
368 self.servers.get(&id),
369 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
370 ) {
371 self.stop_server(&id, cx).log_err();
372 }
373
374 let task = cx.spawn({
375 let id = server.id();
376 let server = server.clone();
377 let configuration = configuration.clone();
378 async move |this, cx| {
379 match server.clone().start(&cx).await {
380 Ok(_) => {
381 log::info!("Started {} context server", id);
382 debug_assert!(server.client().is_some());
383
384 this.update(cx, |this, cx| {
385 this.update_server_state(
386 id.clone(),
387 ContextServerState::Running {
388 server,
389 configuration,
390 },
391 cx,
392 )
393 })
394 .log_err()
395 }
396 Err(err) => {
397 log::error!("{} context server failed to start: {}", id, err);
398 this.update(cx, |this, cx| {
399 this.update_server_state(
400 id.clone(),
401 ContextServerState::Error {
402 configuration,
403 server,
404 error: err.to_string().into(),
405 },
406 cx,
407 )
408 })
409 .log_err()
410 }
411 };
412 }
413 });
414
415 self.update_server_state(
416 id.clone(),
417 ContextServerState::Starting {
418 configuration,
419 _task: task,
420 server,
421 },
422 cx,
423 );
424 }
425
426 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
427 let state = self
428 .servers
429 .remove(id)
430 .context("Context server not found")?;
431 drop(state);
432 cx.emit(Event::ServerStatusChanged {
433 server_id: id.clone(),
434 status: ContextServerStatus::Stopped,
435 });
436 Ok(())
437 }
438
439 fn create_context_server(
440 &self,
441 id: ContextServerId,
442 configuration: Arc<ContextServerConfiguration>,
443 ) -> Result<Arc<ContextServer>> {
444 if let Some(factory) = self.context_server_factory.as_ref() {
445 Ok(factory(id, configuration))
446 } else {
447 Ok(Arc::new(ContextServer::stdio(
448 id,
449 configuration.command().clone(),
450 )))
451 }
452 }
453
454 fn resolve_context_server_settings<'a>(
455 worktree_store: &'a Entity<WorktreeStore>,
456 cx: &'a App,
457 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
458 let location = worktree_store
459 .read(cx)
460 .visible_worktrees(cx)
461 .next()
462 .map(|worktree| settings::SettingsLocation {
463 worktree_id: worktree.read(cx).id(),
464 path: Path::new(""),
465 });
466 &ProjectSettings::get(location, cx).context_servers
467 }
468
469 fn update_server_state(
470 &mut self,
471 id: ContextServerId,
472 state: ContextServerState,
473 cx: &mut Context<Self>,
474 ) {
475 let status = ContextServerStatus::from_state(&state);
476 self.servers.insert(id.clone(), state);
477 cx.emit(Event::ServerStatusChanged {
478 server_id: id,
479 status,
480 });
481 }
482
483 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
484 if self.update_servers_task.is_some() {
485 self.needs_server_update = true;
486 } else {
487 self.needs_server_update = false;
488 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
489 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
490 log::error!("Error maintaining context servers: {}", err);
491 }
492
493 this.update(cx, |this, cx| {
494 this.update_servers_task.take();
495 if this.needs_server_update {
496 this.available_context_servers_changed(cx);
497 }
498 })?;
499
500 Ok(())
501 }));
502 }
503 }
504
505 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
506 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
507 (
508 this.context_server_settings.clone(),
509 this.registry.clone(),
510 this.worktree_store.clone(),
511 )
512 })?;
513
514 for (id, _) in
515 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
516 {
517 configured_servers
518 .entry(id)
519 .or_insert(ContextServerSettings::default_extension());
520 }
521
522 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
523 configured_servers
524 .into_iter()
525 .partition(|(_, settings)| settings.enabled());
526
527 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
528 let id = ContextServerId(id);
529 ContextServerConfiguration::from_settings(
530 settings,
531 id.clone(),
532 registry.clone(),
533 worktree_store.clone(),
534 cx,
535 )
536 .map(|config| (id, config))
537 }))
538 .await
539 .into_iter()
540 .filter_map(|(id, config)| config.map(|config| (id, config)))
541 .collect::<HashMap<_, _>>();
542
543 let mut servers_to_start = Vec::new();
544 let mut servers_to_remove = HashSet::default();
545 let mut servers_to_stop = HashSet::default();
546
547 this.update(cx, |this, _cx| {
548 for server_id in this.servers.keys() {
549 // All servers that are not in desired_servers should be removed from the store.
550 // This can happen if the user removed a server from the context server settings.
551 if !configured_servers.contains_key(&server_id) {
552 if disabled_servers.contains_key(&server_id.0) {
553 servers_to_stop.insert(server_id.clone());
554 } else {
555 servers_to_remove.insert(server_id.clone());
556 }
557 }
558 }
559
560 for (id, config) in configured_servers {
561 let state = this.servers.get(&id);
562 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
563 let existing_config = state.as_ref().map(|state| state.configuration());
564 if existing_config.as_deref() != Some(&config) || is_stopped {
565 let config = Arc::new(config);
566 if let Some(server) = this
567 .create_context_server(id.clone(), config.clone())
568 .log_err()
569 {
570 servers_to_start.push((server, config));
571 if this.servers.contains_key(&id) {
572 servers_to_stop.insert(id);
573 }
574 }
575 }
576 }
577 })?;
578
579 this.update(cx, |this, cx| {
580 for id in servers_to_stop {
581 this.stop_server(&id, cx)?;
582 }
583 for id in servers_to_remove {
584 this.remove_server(&id, cx)?;
585 }
586 for (server, config) in servers_to_start {
587 this.run_server(server, config, cx);
588 }
589 anyhow::Ok(())
590 })?
591 }
592}
593
594#[cfg(test)]
595mod tests {
596 use super::*;
597 use crate::{
598 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
599 project_settings::ProjectSettings,
600 };
601 use context_server::test::create_fake_transport;
602 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
603 use serde_json::json;
604 use std::{cell::RefCell, rc::Rc};
605 use util::path;
606
607 #[gpui::test]
608 async fn test_context_server_status(cx: &mut TestAppContext) {
609 const SERVER_1_ID: &'static str = "mcp-1";
610 const SERVER_2_ID: &'static str = "mcp-2";
611
612 let (_fs, project) = setup_context_server_test(
613 cx,
614 json!({"code.rs": ""}),
615 vec![
616 (SERVER_1_ID.into(), dummy_server_settings()),
617 (SERVER_2_ID.into(), dummy_server_settings()),
618 ],
619 )
620 .await;
621
622 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
623 let store = cx.new(|cx| {
624 ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
625 });
626
627 let server_1_id = ContextServerId(SERVER_1_ID.into());
628 let server_2_id = ContextServerId(SERVER_2_ID.into());
629
630 let server_1 = Arc::new(ContextServer::new(
631 server_1_id.clone(),
632 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
633 ));
634 let server_2 = Arc::new(ContextServer::new(
635 server_2_id.clone(),
636 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
637 ));
638
639 store.update(cx, |store, cx| store.start_server(server_1, cx));
640
641 cx.run_until_parked();
642
643 cx.update(|cx| {
644 assert_eq!(
645 store.read(cx).status_for_server(&server_1_id),
646 Some(ContextServerStatus::Running)
647 );
648 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
649 });
650
651 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
652
653 cx.run_until_parked();
654
655 cx.update(|cx| {
656 assert_eq!(
657 store.read(cx).status_for_server(&server_1_id),
658 Some(ContextServerStatus::Running)
659 );
660 assert_eq!(
661 store.read(cx).status_for_server(&server_2_id),
662 Some(ContextServerStatus::Running)
663 );
664 });
665
666 store
667 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
668 .unwrap();
669
670 cx.update(|cx| {
671 assert_eq!(
672 store.read(cx).status_for_server(&server_1_id),
673 Some(ContextServerStatus::Running)
674 );
675 assert_eq!(
676 store.read(cx).status_for_server(&server_2_id),
677 Some(ContextServerStatus::Stopped)
678 );
679 });
680 }
681
682 #[gpui::test]
683 async fn test_context_server_status_events(cx: &mut TestAppContext) {
684 const SERVER_1_ID: &'static str = "mcp-1";
685 const SERVER_2_ID: &'static str = "mcp-2";
686
687 let (_fs, project) = setup_context_server_test(
688 cx,
689 json!({"code.rs": ""}),
690 vec![
691 (SERVER_1_ID.into(), dummy_server_settings()),
692 (SERVER_2_ID.into(), dummy_server_settings()),
693 ],
694 )
695 .await;
696
697 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
698 let store = cx.new(|cx| {
699 ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
700 });
701
702 let server_1_id = ContextServerId(SERVER_1_ID.into());
703 let server_2_id = ContextServerId(SERVER_2_ID.into());
704
705 let server_1 = Arc::new(ContextServer::new(
706 server_1_id.clone(),
707 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
708 ));
709 let server_2 = Arc::new(ContextServer::new(
710 server_2_id.clone(),
711 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
712 ));
713
714 let _server_events = assert_server_events(
715 &store,
716 vec![
717 (server_1_id.clone(), ContextServerStatus::Starting),
718 (server_1_id.clone(), ContextServerStatus::Running),
719 (server_2_id.clone(), ContextServerStatus::Starting),
720 (server_2_id.clone(), ContextServerStatus::Running),
721 (server_2_id.clone(), ContextServerStatus::Stopped),
722 ],
723 cx,
724 );
725
726 store.update(cx, |store, cx| store.start_server(server_1, cx));
727
728 cx.run_until_parked();
729
730 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
731
732 cx.run_until_parked();
733
734 store
735 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
736 .unwrap();
737 }
738
739 #[gpui::test(iterations = 25)]
740 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
741 const SERVER_1_ID: &'static str = "mcp-1";
742
743 let (_fs, project) = setup_context_server_test(
744 cx,
745 json!({"code.rs": ""}),
746 vec![(SERVER_1_ID.into(), dummy_server_settings())],
747 )
748 .await;
749
750 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
751 let store = cx.new(|cx| {
752 ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
753 });
754
755 let server_id = ContextServerId(SERVER_1_ID.into());
756
757 let server_with_same_id_1 = Arc::new(ContextServer::new(
758 server_id.clone(),
759 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
760 ));
761 let server_with_same_id_2 = Arc::new(ContextServer::new(
762 server_id.clone(),
763 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
764 ));
765
766 // If we start another server with the same id, we should report that we stopped the previous one
767 let _server_events = assert_server_events(
768 &store,
769 vec![
770 (server_id.clone(), ContextServerStatus::Starting),
771 (server_id.clone(), ContextServerStatus::Stopped),
772 (server_id.clone(), ContextServerStatus::Starting),
773 (server_id.clone(), ContextServerStatus::Running),
774 ],
775 cx,
776 );
777
778 store.update(cx, |store, cx| {
779 store.start_server(server_with_same_id_1.clone(), cx)
780 });
781 store.update(cx, |store, cx| {
782 store.start_server(server_with_same_id_2.clone(), cx)
783 });
784
785 cx.run_until_parked();
786
787 cx.update(|cx| {
788 assert_eq!(
789 store.read(cx).status_for_server(&server_id),
790 Some(ContextServerStatus::Running)
791 );
792 });
793 }
794
795 #[gpui::test]
796 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
797 const SERVER_1_ID: &'static str = "mcp-1";
798 const SERVER_2_ID: &'static str = "mcp-2";
799
800 let server_1_id = ContextServerId(SERVER_1_ID.into());
801 let server_2_id = ContextServerId(SERVER_2_ID.into());
802
803 let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
804
805 let (_fs, project) = setup_context_server_test(
806 cx,
807 json!({"code.rs": ""}),
808 vec![(
809 SERVER_1_ID.into(),
810 ContextServerSettings::Extension {
811 enabled: true,
812 settings: json!({
813 "somevalue": true
814 }),
815 },
816 )],
817 )
818 .await;
819
820 let executor = cx.executor();
821 let registry = cx.new(|_| {
822 let mut registry = ContextServerDescriptorRegistry::new();
823 registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
824 registry
825 });
826 let store = cx.new(|cx| {
827 ContextServerStore::test_maintain_server_loop(
828 Box::new(move |id, _| {
829 Arc::new(ContextServer::new(
830 id.clone(),
831 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
832 ))
833 }),
834 registry.clone(),
835 project.read(cx).worktree_store(),
836 cx,
837 )
838 });
839
840 // Ensure that mcp-1 starts up
841 {
842 let _server_events = assert_server_events(
843 &store,
844 vec![
845 (server_1_id.clone(), ContextServerStatus::Starting),
846 (server_1_id.clone(), ContextServerStatus::Running),
847 ],
848 cx,
849 );
850 cx.run_until_parked();
851 }
852
853 // Ensure that mcp-1 is restarted when the configuration was changed
854 {
855 let _server_events = assert_server_events(
856 &store,
857 vec![
858 (server_1_id.clone(), ContextServerStatus::Stopped),
859 (server_1_id.clone(), ContextServerStatus::Starting),
860 (server_1_id.clone(), ContextServerStatus::Running),
861 ],
862 cx,
863 );
864 set_context_server_configuration(
865 vec![(
866 server_1_id.0.clone(),
867 ContextServerSettings::Extension {
868 enabled: true,
869 settings: json!({
870 "somevalue": false
871 }),
872 },
873 )],
874 cx,
875 );
876
877 cx.run_until_parked();
878 }
879
880 // Ensure that mcp-1 is not restarted when the configuration was not changed
881 {
882 let _server_events = assert_server_events(&store, vec![], cx);
883 set_context_server_configuration(
884 vec![(
885 server_1_id.0.clone(),
886 ContextServerSettings::Extension {
887 enabled: true,
888 settings: json!({
889 "somevalue": false
890 }),
891 },
892 )],
893 cx,
894 );
895
896 cx.run_until_parked();
897 }
898
899 // Ensure that mcp-2 is started once it is added to the settings
900 {
901 let _server_events = assert_server_events(
902 &store,
903 vec![
904 (server_2_id.clone(), ContextServerStatus::Starting),
905 (server_2_id.clone(), ContextServerStatus::Running),
906 ],
907 cx,
908 );
909 set_context_server_configuration(
910 vec![
911 (
912 server_1_id.0.clone(),
913 ContextServerSettings::Extension {
914 enabled: true,
915 settings: json!({
916 "somevalue": false
917 }),
918 },
919 ),
920 (
921 server_2_id.0.clone(),
922 ContextServerSettings::Custom {
923 enabled: true,
924 command: ContextServerCommand {
925 path: "somebinary".to_string(),
926 args: vec!["arg".to_string()],
927 env: None,
928 },
929 },
930 ),
931 ],
932 cx,
933 );
934
935 cx.run_until_parked();
936 }
937
938 // Ensure that mcp-2 is restarted once the args have changed
939 {
940 let _server_events = assert_server_events(
941 &store,
942 vec![
943 (server_2_id.clone(), ContextServerStatus::Stopped),
944 (server_2_id.clone(), ContextServerStatus::Starting),
945 (server_2_id.clone(), ContextServerStatus::Running),
946 ],
947 cx,
948 );
949 set_context_server_configuration(
950 vec![
951 (
952 server_1_id.0.clone(),
953 ContextServerSettings::Extension {
954 enabled: true,
955 settings: json!({
956 "somevalue": false
957 }),
958 },
959 ),
960 (
961 server_2_id.0.clone(),
962 ContextServerSettings::Custom {
963 enabled: true,
964 command: ContextServerCommand {
965 path: "somebinary".to_string(),
966 args: vec!["anotherArg".to_string()],
967 env: None,
968 },
969 },
970 ),
971 ],
972 cx,
973 );
974
975 cx.run_until_parked();
976 }
977
978 // Ensure that mcp-2 is removed once it is removed from the settings
979 {
980 let _server_events = assert_server_events(
981 &store,
982 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
983 cx,
984 );
985 set_context_server_configuration(
986 vec![(
987 server_1_id.0.clone(),
988 ContextServerSettings::Extension {
989 enabled: true,
990 settings: json!({
991 "somevalue": false
992 }),
993 },
994 )],
995 cx,
996 );
997
998 cx.run_until_parked();
999
1000 cx.update(|cx| {
1001 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1002 });
1003 }
1004
1005 // Ensure that nothing happens if the settings do not change
1006 {
1007 let _server_events = assert_server_events(&store, vec![], cx);
1008 set_context_server_configuration(
1009 vec![(
1010 server_1_id.0.clone(),
1011 ContextServerSettings::Extension {
1012 enabled: true,
1013 settings: json!({
1014 "somevalue": false
1015 }),
1016 },
1017 )],
1018 cx,
1019 );
1020
1021 cx.run_until_parked();
1022
1023 cx.update(|cx| {
1024 assert_eq!(
1025 store.read(cx).status_for_server(&server_1_id),
1026 Some(ContextServerStatus::Running)
1027 );
1028 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1029 });
1030 }
1031 }
1032
1033 #[gpui::test]
1034 async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1035 const SERVER_1_ID: &'static str = "mcp-1";
1036
1037 let server_1_id = ContextServerId(SERVER_1_ID.into());
1038
1039 let (_fs, project) = setup_context_server_test(
1040 cx,
1041 json!({"code.rs": ""}),
1042 vec![(
1043 SERVER_1_ID.into(),
1044 ContextServerSettings::Custom {
1045 enabled: true,
1046 command: ContextServerCommand {
1047 path: "somebinary".to_string(),
1048 args: vec!["arg".to_string()],
1049 env: None,
1050 },
1051 },
1052 )],
1053 )
1054 .await;
1055
1056 let executor = cx.executor();
1057 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1058 let store = cx.new(|cx| {
1059 ContextServerStore::test_maintain_server_loop(
1060 Box::new(move |id, _| {
1061 Arc::new(ContextServer::new(
1062 id.clone(),
1063 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1064 ))
1065 }),
1066 registry.clone(),
1067 project.read(cx).worktree_store(),
1068 cx,
1069 )
1070 });
1071
1072 // Ensure that mcp-1 starts up
1073 {
1074 let _server_events = assert_server_events(
1075 &store,
1076 vec![
1077 (server_1_id.clone(), ContextServerStatus::Starting),
1078 (server_1_id.clone(), ContextServerStatus::Running),
1079 ],
1080 cx,
1081 );
1082 cx.run_until_parked();
1083 }
1084
1085 // Ensure that mcp-1 is stopped once it is disabled.
1086 {
1087 let _server_events = assert_server_events(
1088 &store,
1089 vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1090 cx,
1091 );
1092 set_context_server_configuration(
1093 vec![(
1094 server_1_id.0.clone(),
1095 ContextServerSettings::Custom {
1096 enabled: false,
1097 command: ContextServerCommand {
1098 path: "somebinary".to_string(),
1099 args: vec!["arg".to_string()],
1100 env: None,
1101 },
1102 },
1103 )],
1104 cx,
1105 );
1106
1107 cx.run_until_parked();
1108 }
1109
1110 // Ensure that mcp-1 is started once it is enabled again.
1111 {
1112 let _server_events = assert_server_events(
1113 &store,
1114 vec![
1115 (server_1_id.clone(), ContextServerStatus::Starting),
1116 (server_1_id.clone(), ContextServerStatus::Running),
1117 ],
1118 cx,
1119 );
1120 set_context_server_configuration(
1121 vec![(
1122 server_1_id.0.clone(),
1123 ContextServerSettings::Custom {
1124 enabled: true,
1125 command: ContextServerCommand {
1126 path: "somebinary".to_string(),
1127 args: vec!["arg".to_string()],
1128 env: None,
1129 },
1130 },
1131 )],
1132 cx,
1133 );
1134
1135 cx.run_until_parked();
1136 }
1137 }
1138
1139 fn set_context_server_configuration(
1140 context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1141 cx: &mut TestAppContext,
1142 ) {
1143 cx.update(|cx| {
1144 SettingsStore::update_global(cx, |store, cx| {
1145 let mut settings = ProjectSettings::default();
1146 for (id, config) in context_servers {
1147 settings.context_servers.insert(id, config);
1148 }
1149 store
1150 .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1151 .unwrap();
1152 })
1153 });
1154 }
1155
    /// Guard returned by `assert_server_events`; when dropped it checks that
    /// the number of received events matches the number of expected events.
    struct ServerEvents {
        // Shared with the subscription callback, which increments it per event.
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        // Keeps the event subscription alive for as long as the guard lives.
        _subscription: Subscription,
    }
1161
1162 impl Drop for ServerEvents {
1163 fn drop(&mut self) {
1164 let actual_event_count = *self.received_event_count.borrow();
1165 assert_eq!(
1166 actual_event_count, self.expected_event_count,
1167 "
1168 Expected to receive {} context server store events, but received {} events",
1169 self.expected_event_count, actual_event_count
1170 );
1171 }
1172 }
1173
1174 fn dummy_server_settings() -> ContextServerSettings {
1175 ContextServerSettings::Custom {
1176 enabled: true,
1177 command: ContextServerCommand {
1178 path: "somebinary".to_string(),
1179 args: vec!["arg".to_string()],
1180 env: None,
1181 },
1182 }
1183 }
1184
1185 fn assert_server_events(
1186 store: &Entity<ContextServerStore>,
1187 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1188 cx: &mut TestAppContext,
1189 ) -> ServerEvents {
1190 cx.update(|cx| {
1191 let mut ix = 0;
1192 let received_event_count = Rc::new(RefCell::new(0));
1193 let expected_event_count = expected_events.len();
1194 let subscription = cx.subscribe(store, {
1195 let received_event_count = received_event_count.clone();
1196 move |_, event, _| match event {
1197 Event::ServerStatusChanged {
1198 server_id: actual_server_id,
1199 status: actual_status,
1200 } => {
1201 let (expected_server_id, expected_status) = &expected_events[ix];
1202
1203 assert_eq!(
1204 actual_server_id, expected_server_id,
1205 "Expected different server id at index {}",
1206 ix
1207 );
1208 assert_eq!(
1209 actual_status, expected_status,
1210 "Expected different status at index {}",
1211 ix
1212 );
1213 ix += 1;
1214 *received_event_count.borrow_mut() += 1;
1215 }
1216 }
1217 });
1218 ServerEvents {
1219 expected_event_count,
1220 received_event_count,
1221 _subscription: subscription,
1222 }
1223 })
1224 }
1225
1226 async fn setup_context_server_test(
1227 cx: &mut TestAppContext,
1228 files: serde_json::Value,
1229 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1230 ) -> (Arc<FakeFs>, Entity<Project>) {
1231 cx.update(|cx| {
1232 let settings_store = SettingsStore::test(cx);
1233 cx.set_global(settings_store);
1234 Project::init_settings(cx);
1235 let mut settings = ProjectSettings::get_global(cx).clone();
1236 for (id, config) in context_server_configurations {
1237 settings.context_servers.insert(id, config);
1238 }
1239 ProjectSettings::override_global(settings, cx);
1240 });
1241
1242 let fs = FakeFs::new(cx.executor());
1243 fs.insert_tree(path!("/test"), files).await;
1244 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1245
1246 (fs, project)
1247 }
1248
    /// Test descriptor whose command always uses a fixed path.
    struct FakeContextServerDescriptor {
        // Reported as the `path` of the command this descriptor produces.
        path: String,
    }
1252
1253 impl FakeContextServerDescriptor {
1254 fn new(path: impl Into<String>) -> Self {
1255 Self { path: path.into() }
1256 }
1257 }
1258
1259 impl ContextServerDescriptor for FakeContextServerDescriptor {
1260 fn command(
1261 &self,
1262 _worktree_store: Entity<WorktreeStore>,
1263 _cx: &AsyncApp,
1264 ) -> Task<Result<ContextServerCommand>> {
1265 Task::ready(Ok(ContextServerCommand {
1266 path: self.path.clone(),
1267 args: vec!["arg1".to_string(), "arg2".to_string()],
1268 env: None,
1269 }))
1270 }
1271
1272 fn configuration(
1273 &self,
1274 _worktree_store: Entity<WorktreeStore>,
1275 _cx: &AsyncApp,
1276 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1277 Task::ready(Ok(None))
1278 }
1279 }
1280}