1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::ResultExt as _;
14
15use crate::{
16 project_settings::{ContextServerSettings, ProjectSettings},
17 worktree_store::WorktreeStore,
18};
19
20pub fn init(cx: &mut App) {
21 extension::init(cx);
22}
23
24actions!(context_server, [Restart]);
25
/// Public, cloneable summary of a context server's lifecycle state.
///
/// Derived from the internal `ContextServerState` and reported both through
/// `Event::ServerStatusChanged` and `ContextServerStore::status_for_server`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's startup task has been spawned but has not finished yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped (or removed from the store).
    Stopped,
    /// Startup failed; carries the error message.
    Error(Arc<str>),
}
33
34impl ContextServerStatus {
35 fn from_state(state: &ContextServerState) -> Self {
36 match state {
37 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
38 ContextServerState::Running { .. } => ContextServerStatus::Running,
39 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
40 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
41 }
42 }
43}
44
/// Internal lifecycle state of a tracked context server.
///
/// Every variant keeps the server handle and the configuration it was (or is
/// being) started with, so a server can be inspected or restarted from any state.
enum ContextServerState {
    /// Startup is in progress.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Holds the spawned startup task. Dropping this state drops the task
        // (which presumably cancels startup, per gpui `Task` drop semantics — confirm).
        _task: Task<()>,
    },
    /// `ContextServer::start` returned `Ok`.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped explicitly or because it was disabled in settings.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// `ContextServer::start` returned an error; `error` is its message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
65
66impl ContextServerState {
67 pub fn server(&self) -> Arc<ContextServer> {
68 match self {
69 ContextServerState::Starting { server, .. } => server.clone(),
70 ContextServerState::Running { server, .. } => server.clone(),
71 ContextServerState::Stopped { server, .. } => server.clone(),
72 ContextServerState::Error { server, .. } => server.clone(),
73 }
74 }
75
76 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
77 match self {
78 ContextServerState::Starting { configuration, .. } => configuration.clone(),
79 ContextServerState::Running { configuration, .. } => configuration.clone(),
80 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
81 ContextServerState::Error { configuration, .. } => configuration.clone(),
82 }
83 }
84}
85
/// Resolved, launch-ready configuration of a context server.
///
/// Compared with `PartialEq` by the maintenance loop to decide whether a running
/// server must be restarted after a settings or registry change.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server whose command comes directly from the user's settings.
    Custom {
        command: ContextServerCommand,
    },
    /// A server provided by an extension descriptor.
    Extension {
        /// Command produced by the extension's descriptor.
        command: ContextServerCommand,
        /// Extension-specific settings from the user's settings.
        settings: serde_json::Value,
    },
}
96
97impl ContextServerConfiguration {
98 pub fn command(&self) -> &ContextServerCommand {
99 match self {
100 ContextServerConfiguration::Custom { command } => command,
101 ContextServerConfiguration::Extension { command, .. } => command,
102 }
103 }
104
105 pub async fn from_settings(
106 settings: ContextServerSettings,
107 id: ContextServerId,
108 registry: Entity<ContextServerDescriptorRegistry>,
109 worktree_store: Entity<WorktreeStore>,
110 cx: &AsyncApp,
111 ) -> Option<Self> {
112 match settings {
113 ContextServerSettings::Custom {
114 enabled: _,
115 command,
116 } => Some(ContextServerConfiguration::Custom { command }),
117 ContextServerSettings::Extension {
118 enabled: _,
119 settings,
120 } => {
121 let descriptor = cx
122 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
123 .ok()
124 .flatten()?;
125
126 let command = descriptor.command(worktree_store, cx).await.log_err()?;
127
128 Some(ContextServerConfiguration::Extension { command, settings })
129 }
130 }
131 }
132}
133
/// Constructs a `ContextServer` for a given id and configuration.
///
/// When a store is given a factory, `create_context_server` uses it instead of
/// building a stdio server from the configuration's command (the tests use this
/// to plug in fake transports).
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
136
/// Tracks every configured context server and its lifecycle state, and (unless
/// created via `ContextServerStore::test`) keeps the set of servers in sync
/// with the settings and the extension descriptor registry.
pub struct ContextServerStore {
    // Context server settings as last resolved from the settings store, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current lifecycle state of every tracked server.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to resolve worktree-scoped settings and handed to extension descriptors.
    worktree_store: Entity<WorktreeStore>,
    // Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // The in-flight `maintain_servers` pass, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for building servers; `None` means stdio servers.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a change arrives while a maintenance pass is running, so that one
    // more pass is scheduled after it finishes.
    needs_server_update: bool,
    // Registry/settings observers; empty when the maintenance loop is disabled.
    _subscriptions: Vec<Subscription>,
}
147
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// A server's status changed. Also emitted with `ContextServerStatus::Stopped`
    /// when a server is removed from the store entirely.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
156
157impl ContextServerStore {
158 pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
159 Self::new_internal(
160 true,
161 None,
162 ContextServerDescriptorRegistry::default_global(cx),
163 worktree_store,
164 cx,
165 )
166 }
167
168 #[cfg(any(test, feature = "test-support"))]
169 pub fn test(
170 registry: Entity<ContextServerDescriptorRegistry>,
171 worktree_store: Entity<WorktreeStore>,
172 cx: &mut Context<Self>,
173 ) -> Self {
174 Self::new_internal(false, None, registry, worktree_store, cx)
175 }
176
177 #[cfg(any(test, feature = "test-support"))]
178 pub fn test_maintain_server_loop(
179 context_server_factory: ContextServerFactory,
180 registry: Entity<ContextServerDescriptorRegistry>,
181 worktree_store: Entity<WorktreeStore>,
182 cx: &mut Context<Self>,
183 ) -> Self {
184 Self::new_internal(
185 true,
186 Some(context_server_factory),
187 registry,
188 worktree_store,
189 cx,
190 )
191 }
192
193 fn new_internal(
194 maintain_server_loop: bool,
195 context_server_factory: Option<ContextServerFactory>,
196 registry: Entity<ContextServerDescriptorRegistry>,
197 worktree_store: Entity<WorktreeStore>,
198 cx: &mut Context<Self>,
199 ) -> Self {
200 let subscriptions = if maintain_server_loop {
201 vec![
202 cx.observe(®istry, |this, _registry, cx| {
203 this.available_context_servers_changed(cx);
204 }),
205 cx.observe_global::<SettingsStore>(|this, cx| {
206 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
207 if &this.context_server_settings == settings {
208 return;
209 }
210 this.context_server_settings = settings.clone();
211 this.available_context_servers_changed(cx);
212 }),
213 ]
214 } else {
215 Vec::new()
216 };
217
218 let mut this = Self {
219 _subscriptions: subscriptions,
220 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
221 .clone(),
222 worktree_store,
223 registry,
224 needs_server_update: false,
225 servers: HashMap::default(),
226 update_servers_task: None,
227 context_server_factory,
228 };
229 if maintain_server_loop {
230 this.available_context_servers_changed(cx);
231 }
232 this
233 }
234
235 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
236 self.servers.get(id).map(|state| state.server())
237 }
238
239 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
240 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
241 Some(server.clone())
242 } else {
243 None
244 }
245 }
246
247 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
248 self.servers.get(id).map(ContextServerStatus::from_state)
249 }
250
251 pub fn configuration_for_server(
252 &self,
253 id: &ContextServerId,
254 ) -> Option<Arc<ContextServerConfiguration>> {
255 self.servers.get(id).map(|state| state.configuration())
256 }
257
258 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
259 self.servers.keys().cloned().collect()
260 }
261
262 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
263 self.servers
264 .values()
265 .filter_map(|state| {
266 if let ContextServerState::Running { server, .. } = state {
267 Some(server.clone())
268 } else {
269 None
270 }
271 })
272 .collect()
273 }
274
275 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
276 cx.spawn(async move |this, cx| {
277 let this = this.upgrade().context("Context server store dropped")?;
278 let settings = this
279 .update(cx, |this, _| {
280 this.context_server_settings.get(&server.id().0).cloned()
281 })
282 .ok()
283 .flatten()
284 .context("Failed to get context server settings")?;
285
286 if !settings.enabled() {
287 return Ok(());
288 }
289
290 let (registry, worktree_store) = this.update(cx, |this, _| {
291 (this.registry.clone(), this.worktree_store.clone())
292 })?;
293 let configuration = ContextServerConfiguration::from_settings(
294 settings,
295 server.id(),
296 registry,
297 worktree_store,
298 cx,
299 )
300 .await
301 .context("Failed to create context server configuration")?;
302
303 this.update(cx, |this, cx| {
304 this.run_server(server, Arc::new(configuration), cx)
305 })
306 })
307 .detach_and_log_err(cx);
308 }
309
310 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
311 if matches!(
312 self.servers.get(id),
313 Some(ContextServerState::Stopped { .. })
314 ) {
315 return Ok(());
316 }
317
318 let state = self
319 .servers
320 .remove(id)
321 .context("Context server not found")?;
322
323 let server = state.server();
324 let configuration = state.configuration();
325 let mut result = Ok(());
326 if let ContextServerState::Running { server, .. } = &state {
327 result = server.stop();
328 }
329 drop(state);
330
331 self.update_server_state(
332 id.clone(),
333 ContextServerState::Stopped {
334 configuration,
335 server,
336 },
337 cx,
338 );
339
340 result
341 }
342
343 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
344 if let Some(state) = self.servers.get(&id) {
345 let configuration = state.configuration();
346
347 self.stop_server(&state.server().id(), cx)?;
348 let new_server = self.create_context_server(id.clone(), configuration.clone())?;
349 self.run_server(new_server, configuration, cx);
350 }
351 Ok(())
352 }
353
354 fn run_server(
355 &mut self,
356 server: Arc<ContextServer>,
357 configuration: Arc<ContextServerConfiguration>,
358 cx: &mut Context<Self>,
359 ) {
360 let id = server.id();
361 if matches!(
362 self.servers.get(&id),
363 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
364 ) {
365 self.stop_server(&id, cx).log_err();
366 }
367
368 let task = cx.spawn({
369 let id = server.id();
370 let server = server.clone();
371 let configuration = configuration.clone();
372 async move |this, cx| {
373 match server.clone().start(&cx).await {
374 Ok(_) => {
375 log::info!("Started {} context server", id);
376 debug_assert!(server.client().is_some());
377
378 this.update(cx, |this, cx| {
379 this.update_server_state(
380 id.clone(),
381 ContextServerState::Running {
382 server,
383 configuration,
384 },
385 cx,
386 )
387 })
388 .log_err()
389 }
390 Err(err) => {
391 log::error!("{} context server failed to start: {}", id, err);
392 this.update(cx, |this, cx| {
393 this.update_server_state(
394 id.clone(),
395 ContextServerState::Error {
396 configuration,
397 server,
398 error: err.to_string().into(),
399 },
400 cx,
401 )
402 })
403 .log_err()
404 }
405 };
406 }
407 });
408
409 self.update_server_state(
410 id.clone(),
411 ContextServerState::Starting {
412 configuration,
413 _task: task,
414 server,
415 },
416 cx,
417 );
418 }
419
420 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
421 let state = self
422 .servers
423 .remove(id)
424 .context("Context server not found")?;
425 drop(state);
426 cx.emit(Event::ServerStatusChanged {
427 server_id: id.clone(),
428 status: ContextServerStatus::Stopped,
429 });
430 Ok(())
431 }
432
433 fn create_context_server(
434 &self,
435 id: ContextServerId,
436 configuration: Arc<ContextServerConfiguration>,
437 ) -> Result<Arc<ContextServer>> {
438 if let Some(factory) = self.context_server_factory.as_ref() {
439 Ok(factory(id, configuration))
440 } else {
441 Ok(Arc::new(ContextServer::stdio(
442 id,
443 configuration.command().clone(),
444 )))
445 }
446 }
447
448 fn resolve_context_server_settings<'a>(
449 worktree_store: &'a Entity<WorktreeStore>,
450 cx: &'a App,
451 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
452 let location = worktree_store
453 .read(cx)
454 .visible_worktrees(cx)
455 .next()
456 .map(|worktree| settings::SettingsLocation {
457 worktree_id: worktree.read(cx).id(),
458 path: Path::new(""),
459 });
460 &ProjectSettings::get(location, cx).context_servers
461 }
462
463 fn update_server_state(
464 &mut self,
465 id: ContextServerId,
466 state: ContextServerState,
467 cx: &mut Context<Self>,
468 ) {
469 let status = ContextServerStatus::from_state(&state);
470 self.servers.insert(id.clone(), state);
471 cx.emit(Event::ServerStatusChanged {
472 server_id: id,
473 status,
474 });
475 }
476
477 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
478 if self.update_servers_task.is_some() {
479 self.needs_server_update = true;
480 } else {
481 self.needs_server_update = false;
482 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
483 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
484 log::error!("Error maintaining context servers: {}", err);
485 }
486
487 this.update(cx, |this, cx| {
488 this.update_servers_task.take();
489 if this.needs_server_update {
490 this.available_context_servers_changed(cx);
491 }
492 })?;
493
494 Ok(())
495 }));
496 }
497 }
498
499 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
500 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
501 (
502 this.context_server_settings.clone(),
503 this.registry.clone(),
504 this.worktree_store.clone(),
505 )
506 })?;
507
508 for (id, _) in
509 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
510 {
511 configured_servers
512 .entry(id)
513 .or_insert(ContextServerSettings::default_extension());
514 }
515
516 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
517 configured_servers
518 .into_iter()
519 .partition(|(_, settings)| settings.enabled());
520
521 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
522 let id = ContextServerId(id);
523 ContextServerConfiguration::from_settings(
524 settings,
525 id.clone(),
526 registry.clone(),
527 worktree_store.clone(),
528 cx,
529 )
530 .map(|config| (id, config))
531 }))
532 .await
533 .into_iter()
534 .filter_map(|(id, config)| config.map(|config| (id, config)))
535 .collect::<HashMap<_, _>>();
536
537 let mut servers_to_start = Vec::new();
538 let mut servers_to_remove = HashSet::default();
539 let mut servers_to_stop = HashSet::default();
540
541 this.update(cx, |this, _cx| {
542 for server_id in this.servers.keys() {
543 // All servers that are not in desired_servers should be removed from the store.
544 // This can happen if the user removed a server from the context server settings.
545 if !configured_servers.contains_key(&server_id) {
546 if disabled_servers.contains_key(&server_id.0) {
547 servers_to_stop.insert(server_id.clone());
548 } else {
549 servers_to_remove.insert(server_id.clone());
550 }
551 }
552 }
553
554 for (id, config) in configured_servers {
555 let state = this.servers.get(&id);
556 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
557 let existing_config = state.as_ref().map(|state| state.configuration());
558 if existing_config.as_deref() != Some(&config) || is_stopped {
559 let config = Arc::new(config);
560 if let Some(server) = this
561 .create_context_server(id.clone(), config.clone())
562 .log_err()
563 {
564 servers_to_start.push((server, config));
565 if this.servers.contains_key(&id) {
566 servers_to_stop.insert(id);
567 }
568 }
569 }
570 }
571 })?;
572
573 this.update(cx, |this, cx| {
574 for id in servers_to_stop {
575 this.stop_server(&id, cx)?;
576 }
577 for id in servers_to_remove {
578 this.remove_server(&id, cx)?;
579 }
580 for (server, config) in servers_to_start {
581 this.run_server(server, config, cx);
582 }
583 anyhow::Ok(())
584 })?
585 }
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
        project_settings::ProjectSettings,
    };
    use context_server::test::create_fake_transport;
    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
    use serde_json::json;
    use std::{cell::RefCell, rc::Rc};
    use util::path;

    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }

    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }

    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }

    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Extension {
                    enabled: true,
                    settings: json!({
                        "somevalue": true
                    }),
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| {
            let mut registry = ContextServerDescriptorRegistry::new();
            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
            registry
        });
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Stopped),
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        ContextServerSettings::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        ContextServerSettings::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".to_string(),
                                args: vec!["arg".to_string()],
                                env: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is restarted once the args have changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Stopped),
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        ContextServerSettings::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        ContextServerSettings::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".to_string(),
                                args: vec!["anotherArg".to_string()],
                                env: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }

        // Ensure that nothing happens if the settings do not change
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(
                    store.read(cx).status_for_server(&server_1_id),
                    Some(ContextServerStatus::Running)
                );
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }
    }

    #[gpui::test]
    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let server_1_id = ContextServerId(SERVER_1_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Custom {
                    enabled: true,
                    command: ContextServerCommand {
                        path: "somebinary".to_string(),
                        args: vec!["arg".to_string()],
                        env: None,
                    },
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is stopped once it is disabled.
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Custom {
                        enabled: false,
                        command: ContextServerCommand {
                            path: "somebinary".to_string(),
                            args: vec!["arg".to_string()],
                            env: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is started once it is enabled again.
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerSettings::Custom {
                        enabled: true,
                        command: ContextServerCommand {
                            path: "somebinary".to_string(),
                            args: vec!["arg".to_string()],
                            env: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }
    }

    /// Replaces the user settings with a `ProjectSettings` whose only context
    /// servers are `context_servers`.
    fn set_context_server_configuration(
        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                let mut settings = ProjectSettings::default();
                for (id, config) in context_servers {
                    settings.context_servers.insert(id, config);
                }
                store
                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
                    .unwrap();
            })
        });
    }

    /// Guard returned by `assert_server_events`; checks on drop that exactly the
    /// expected number of events was received.
    struct ServerEvents {
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }

    impl Drop for ServerEvents {
        fn drop(&mut self) {
            // Asserting while the test is already unwinding would cause a double
            // panic, aborting the process and hiding the original failure.
            if std::thread::panicking() {
                return;
            }
            let actual_event_count = *self.received_event_count.borrow();
            assert_eq!(
                actual_event_count, self.expected_event_count,
                "Expected to receive {} context server store events, but received {} events",
                self.expected_event_count, actual_event_count
            );
        }
    }

    /// Settings for an enabled custom server running a dummy binary.
    fn dummy_server_settings() -> ContextServerSettings {
        ContextServerSettings::Custom {
            enabled: true,
            command: ContextServerCommand {
                path: "somebinary".to_string(),
                args: vec!["arg".to_string()],
                env: None,
            },
        }
    }

    /// Subscribes to `store` and asserts that its status events match
    /// `expected_events` in order. The returned guard checks the total count
    /// when it is dropped.
    fn assert_server_events(
        store: &Entity<ContextServerStore>,
        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
        cx: &mut TestAppContext,
    ) -> ServerEvents {
        cx.update(|cx| {
            let mut ix = 0;
            let received_event_count = Rc::new(RefCell::new(0));
            let expected_event_count = expected_events.len();
            let subscription = cx.subscribe(store, {
                let received_event_count = received_event_count.clone();
                move |_, event, _| match event {
                    Event::ServerStatusChanged {
                        server_id: actual_server_id,
                        status: actual_status,
                    } => {
                        let Some((expected_server_id, expected_status)) = expected_events.get(ix)
                        else {
                            panic!(
                                "Received unexpected extra event at index {}: server {:?} -> {:?} (expected only {} events)",
                                ix, actual_server_id, actual_status, expected_event_count
                            );
                        };

                        assert_eq!(
                            actual_server_id, expected_server_id,
                            "Expected different server id at index {}",
                            ix
                        );
                        assert_eq!(
                            actual_status, expected_status,
                            "Expected different status at index {}",
                            ix
                        );
                        ix += 1;
                        *received_event_count.borrow_mut() += 1;
                    }
                }
            });
            ServerEvents {
                expected_event_count,
                received_event_count,
                _subscription: subscription,
            }
        })
    }

    /// Installs test settings containing `context_server_configurations` and
    /// creates a project over a fake filesystem populated with `files`.
    async fn setup_context_server_test(
        cx: &mut TestAppContext,
        files: serde_json::Value,
        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
    ) -> (Arc<FakeFs>, Entity<Project>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            let mut settings = ProjectSettings::get_global(cx).clone();
            for (id, config) in context_server_configurations {
                settings.context_servers.insert(id, config);
            }
            ProjectSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        (fs, project)
    }

    /// Descriptor that always yields a fixed command with `path` as its binary.
    struct FakeContextServerDescriptor {
        path: String,
    }

    impl FakeContextServerDescriptor {
        fn new(path: impl Into<String>) -> Self {
            Self { path: path.into() }
        }
    }

    impl ContextServerDescriptor for FakeContextServerDescriptor {
        fn command(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<ContextServerCommand>> {
            Task::ready(Ok(ContextServerCommand {
                path: self.path.clone(),
                args: vec!["arg1".to_string(), "arg2".to_string()],
                env: None,
            }))
        }

        fn configuration(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
            Task::ready(Ok(None))
        }
    }
}