1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::ResultExt as _;
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 worktree_store::WorktreeStore,
19};
20
21pub fn init(cx: &mut App) {
22 extension::init(cx);
23}
24
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub enum ContextServerStatus {
35 Starting,
36 Running,
37 Stopped,
38 Error(Arc<str>),
39}
40
41impl ContextServerStatus {
42 fn from_state(state: &ContextServerState) -> Self {
43 match state {
44 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
45 ContextServerState::Running { .. } => ContextServerStatus::Running,
46 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
47 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
48 }
49 }
50}
51
52enum ContextServerState {
53 Starting {
54 server: Arc<ContextServer>,
55 configuration: Arc<ContextServerConfiguration>,
56 _task: Task<()>,
57 },
58 Running {
59 server: Arc<ContextServer>,
60 configuration: Arc<ContextServerConfiguration>,
61 },
62 Stopped {
63 server: Arc<ContextServer>,
64 configuration: Arc<ContextServerConfiguration>,
65 },
66 Error {
67 server: Arc<ContextServer>,
68 configuration: Arc<ContextServerConfiguration>,
69 error: Arc<str>,
70 },
71}
72
73impl ContextServerState {
74 pub fn server(&self) -> Arc<ContextServer> {
75 match self {
76 ContextServerState::Starting { server, .. } => server.clone(),
77 ContextServerState::Running { server, .. } => server.clone(),
78 ContextServerState::Stopped { server, .. } => server.clone(),
79 ContextServerState::Error { server, .. } => server.clone(),
80 }
81 }
82
83 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
84 match self {
85 ContextServerState::Starting { configuration, .. } => configuration.clone(),
86 ContextServerState::Running { configuration, .. } => configuration.clone(),
87 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
88 ContextServerState::Error { configuration, .. } => configuration.clone(),
89 }
90 }
91}
92
93#[derive(Debug, PartialEq, Eq)]
94pub enum ContextServerConfiguration {
95 Custom {
96 command: ContextServerCommand,
97 },
98 Extension {
99 command: ContextServerCommand,
100 settings: serde_json::Value,
101 },
102}
103
104impl ContextServerConfiguration {
105 pub fn command(&self) -> &ContextServerCommand {
106 match self {
107 ContextServerConfiguration::Custom { command } => command,
108 ContextServerConfiguration::Extension { command, .. } => command,
109 }
110 }
111
112 pub async fn from_settings(
113 settings: ContextServerSettings,
114 id: ContextServerId,
115 registry: Entity<ContextServerDescriptorRegistry>,
116 worktree_store: Entity<WorktreeStore>,
117 cx: &AsyncApp,
118 ) -> Option<Self> {
119 match settings {
120 ContextServerSettings::Custom {
121 enabled: _,
122 command,
123 } => Some(ContextServerConfiguration::Custom { command }),
124 ContextServerSettings::Extension {
125 enabled: _,
126 settings,
127 } => {
128 let descriptor = cx
129 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
130 .ok()
131 .flatten()?;
132
133 let command = descriptor.command(worktree_store, cx).await.log_err()?;
134
135 Some(ContextServerConfiguration::Extension { command, settings })
136 }
137 }
138 }
139}
140
/// Constructs a [`ContextServer`] for an id and its resolved configuration.
///
/// When a store holds a factory, it is used instead of the default construction of a
/// stdio-based server (see `ContextServerStore::test_maintain_server_loop`).
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
143
/// Tracks the lifecycle of every context server known to a project.
///
/// Stores created with `ContextServerStore::new` keep their servers in sync with the
/// resolved context server settings and the descriptor registry.
pub struct ContextServerStore {
    // Context server settings resolved for the project, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current lifecycle state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    // Used to choose the root path handed to stdio servers.
    project: WeakEntity<Project>,
    // Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // In-flight reconciliation task, if any; at most one is scheduled at a time.
    update_servers_task: Option<Task<Result<()>>>,
    // Overrides server construction when set; `None` builds stdio servers.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a change arrives while a reconciliation is in flight so it reruns afterwards.
    needs_server_update: bool,
    _subscriptions: Vec<Subscription>,
}
155
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// A server's status changed. Emitted on every state update, and with `Stopped` when a
    /// server is removed from the store entirely.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
164
165impl ContextServerStore {
166 pub fn new(
167 worktree_store: Entity<WorktreeStore>,
168 weak_project: WeakEntity<Project>,
169 cx: &mut Context<Self>,
170 ) -> Self {
171 Self::new_internal(
172 true,
173 None,
174 ContextServerDescriptorRegistry::default_global(cx),
175 worktree_store,
176 weak_project,
177 cx,
178 )
179 }
180
181 /// Returns all configured context server ids, regardless of enabled state.
182 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
183 self.context_server_settings
184 .keys()
185 .cloned()
186 .map(ContextServerId)
187 .collect()
188 }
189
190 #[cfg(any(test, feature = "test-support"))]
191 pub fn test(
192 registry: Entity<ContextServerDescriptorRegistry>,
193 worktree_store: Entity<WorktreeStore>,
194 weak_project: WeakEntity<Project>,
195 cx: &mut Context<Self>,
196 ) -> Self {
197 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
198 }
199
200 #[cfg(any(test, feature = "test-support"))]
201 pub fn test_maintain_server_loop(
202 context_server_factory: ContextServerFactory,
203 registry: Entity<ContextServerDescriptorRegistry>,
204 worktree_store: Entity<WorktreeStore>,
205 weak_project: WeakEntity<Project>,
206 cx: &mut Context<Self>,
207 ) -> Self {
208 Self::new_internal(
209 true,
210 Some(context_server_factory),
211 registry,
212 worktree_store,
213 weak_project,
214 cx,
215 )
216 }
217
218 fn new_internal(
219 maintain_server_loop: bool,
220 context_server_factory: Option<ContextServerFactory>,
221 registry: Entity<ContextServerDescriptorRegistry>,
222 worktree_store: Entity<WorktreeStore>,
223 weak_project: WeakEntity<Project>,
224 cx: &mut Context<Self>,
225 ) -> Self {
226 let subscriptions = if maintain_server_loop {
227 vec![
228 cx.observe(®istry, |this, _registry, cx| {
229 this.available_context_servers_changed(cx);
230 }),
231 cx.observe_global::<SettingsStore>(|this, cx| {
232 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
233 if &this.context_server_settings == settings {
234 return;
235 }
236 this.context_server_settings = settings.clone();
237 this.available_context_servers_changed(cx);
238 }),
239 ]
240 } else {
241 Vec::new()
242 };
243
244 let mut this = Self {
245 _subscriptions: subscriptions,
246 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
247 .clone(),
248 worktree_store,
249 project: weak_project,
250 registry,
251 needs_server_update: false,
252 servers: HashMap::default(),
253 update_servers_task: None,
254 context_server_factory,
255 };
256 if maintain_server_loop {
257 this.available_context_servers_changed(cx);
258 }
259 this
260 }
261
262 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
263 self.servers.get(id).map(|state| state.server())
264 }
265
266 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
267 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
268 Some(server.clone())
269 } else {
270 None
271 }
272 }
273
274 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
275 self.servers.get(id).map(ContextServerStatus::from_state)
276 }
277
278 pub fn configuration_for_server(
279 &self,
280 id: &ContextServerId,
281 ) -> Option<Arc<ContextServerConfiguration>> {
282 self.servers.get(id).map(|state| state.configuration())
283 }
284
285 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
286 self.servers.keys().cloned().collect()
287 }
288
289 pub fn all_registry_descriptor_ids(&self, cx: &App) -> Vec<ContextServerId> {
290 self.registry
291 .read(cx)
292 .context_server_descriptors()
293 .into_iter()
294 .map(|(id, _)| ContextServerId(id))
295 .collect()
296 }
297
298 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
299 self.servers
300 .values()
301 .filter_map(|state| {
302 if let ContextServerState::Running { server, .. } = state {
303 Some(server.clone())
304 } else {
305 None
306 }
307 })
308 .collect()
309 }
310
311 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
312 cx.spawn(async move |this, cx| {
313 let this = this.upgrade().context("Context server store dropped")?;
314 let settings = this
315 .update(cx, |this, _| {
316 this.context_server_settings.get(&server.id().0).cloned()
317 })
318 .ok()
319 .flatten()
320 .context("Failed to get context server settings")?;
321
322 if !settings.enabled() {
323 return Ok(());
324 }
325
326 let (registry, worktree_store) = this.update(cx, |this, _| {
327 (this.registry.clone(), this.worktree_store.clone())
328 })?;
329 let configuration = ContextServerConfiguration::from_settings(
330 settings,
331 server.id(),
332 registry,
333 worktree_store,
334 cx,
335 )
336 .await
337 .context("Failed to create context server configuration")?;
338
339 this.update(cx, |this, cx| {
340 this.run_server(server, Arc::new(configuration), cx)
341 })
342 })
343 .detach_and_log_err(cx);
344 }
345
346 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
347 if matches!(
348 self.servers.get(id),
349 Some(ContextServerState::Stopped { .. })
350 ) {
351 return Ok(());
352 }
353
354 let state = self
355 .servers
356 .remove(id)
357 .context("Context server not found")?;
358
359 let server = state.server();
360 let configuration = state.configuration();
361 let mut result = Ok(());
362 if let ContextServerState::Running { server, .. } = &state {
363 result = server.stop();
364 }
365 drop(state);
366
367 self.update_server_state(
368 id.clone(),
369 ContextServerState::Stopped {
370 configuration,
371 server,
372 },
373 cx,
374 );
375
376 result
377 }
378
379 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
380 if let Some(state) = self.servers.get(id) {
381 let configuration = state.configuration();
382
383 self.stop_server(&state.server().id(), cx)?;
384 let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
385 self.run_server(new_server, configuration, cx);
386 }
387 Ok(())
388 }
389
390 fn run_server(
391 &mut self,
392 server: Arc<ContextServer>,
393 configuration: Arc<ContextServerConfiguration>,
394 cx: &mut Context<Self>,
395 ) {
396 let id = server.id();
397 if matches!(
398 self.servers.get(&id),
399 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
400 ) {
401 self.stop_server(&id, cx).log_err();
402 }
403
404 let task = cx.spawn({
405 let id = server.id();
406 let server = server.clone();
407 let configuration = configuration.clone();
408 async move |this, cx| {
409 match server.clone().start(cx).await {
410 Ok(_) => {
411 debug_assert!(server.client().is_some());
412
413 this.update(cx, |this, cx| {
414 this.update_server_state(
415 id.clone(),
416 ContextServerState::Running {
417 server,
418 configuration,
419 },
420 cx,
421 )
422 })
423 .log_err()
424 }
425 Err(err) => {
426 log::error!("{} context server failed to start: {}", id, err);
427 this.update(cx, |this, cx| {
428 this.update_server_state(
429 id.clone(),
430 ContextServerState::Error {
431 configuration,
432 server,
433 error: err.to_string().into(),
434 },
435 cx,
436 )
437 })
438 .log_err()
439 }
440 };
441 }
442 });
443
444 self.update_server_state(
445 id.clone(),
446 ContextServerState::Starting {
447 configuration,
448 _task: task,
449 server,
450 },
451 cx,
452 );
453 }
454
455 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
456 let state = self
457 .servers
458 .remove(id)
459 .context("Context server not found")?;
460 drop(state);
461 cx.emit(Event::ServerStatusChanged {
462 server_id: id.clone(),
463 status: ContextServerStatus::Stopped,
464 });
465 Ok(())
466 }
467
468 fn create_context_server(
469 &self,
470 id: ContextServerId,
471 configuration: Arc<ContextServerConfiguration>,
472 cx: &mut Context<Self>,
473 ) -> Arc<ContextServer> {
474 let root_path = self
475 .project
476 .read_with(cx, |project, cx| project.active_project_directory(cx))
477 .ok()
478 .flatten()
479 .or_else(|| {
480 self.worktree_store.read_with(cx, |store, cx| {
481 store.visible_worktrees(cx).fold(None, |acc, item| {
482 if acc.is_none() {
483 item.read(cx).root_dir()
484 } else {
485 acc
486 }
487 })
488 })
489 });
490
491 if let Some(factory) = self.context_server_factory.as_ref() {
492 factory(id, configuration)
493 } else {
494 Arc::new(ContextServer::stdio(
495 id,
496 configuration.command().clone(),
497 root_path,
498 ))
499 }
500 }
501
502 fn resolve_context_server_settings<'a>(
503 worktree_store: &'a Entity<WorktreeStore>,
504 cx: &'a App,
505 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
506 let location = worktree_store
507 .read(cx)
508 .visible_worktrees(cx)
509 .next()
510 .map(|worktree| settings::SettingsLocation {
511 worktree_id: worktree.read(cx).id(),
512 path: Path::new(""),
513 });
514 &ProjectSettings::get(location, cx).context_servers
515 }
516
517 fn update_server_state(
518 &mut self,
519 id: ContextServerId,
520 state: ContextServerState,
521 cx: &mut Context<Self>,
522 ) {
523 let status = ContextServerStatus::from_state(&state);
524 self.servers.insert(id.clone(), state);
525 cx.emit(Event::ServerStatusChanged {
526 server_id: id,
527 status,
528 });
529 }
530
531 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
532 if self.update_servers_task.is_some() {
533 self.needs_server_update = true;
534 } else {
535 self.needs_server_update = false;
536 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
537 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
538 log::error!("Error maintaining context servers: {}", err);
539 }
540
541 this.update(cx, |this, cx| {
542 this.update_servers_task.take();
543 if this.needs_server_update {
544 this.available_context_servers_changed(cx);
545 }
546 })?;
547
548 Ok(())
549 }));
550 }
551 }
552
553 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
554 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
555 (
556 this.context_server_settings.clone(),
557 this.registry.clone(),
558 this.worktree_store.clone(),
559 )
560 })?;
561
562 for (id, _) in
563 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
564 {
565 configured_servers
566 .entry(id)
567 .or_insert(ContextServerSettings::default_extension());
568 }
569
570 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
571 configured_servers
572 .into_iter()
573 .partition(|(_, settings)| settings.enabled());
574
575 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
576 let id = ContextServerId(id);
577 ContextServerConfiguration::from_settings(
578 settings,
579 id.clone(),
580 registry.clone(),
581 worktree_store.clone(),
582 cx,
583 )
584 .map(|config| (id, config))
585 }))
586 .await
587 .into_iter()
588 .filter_map(|(id, config)| config.map(|config| (id, config)))
589 .collect::<HashMap<_, _>>();
590
591 let mut servers_to_start = Vec::new();
592 let mut servers_to_remove = HashSet::default();
593 let mut servers_to_stop = HashSet::default();
594
595 this.update(cx, |this, cx| {
596 for server_id in this.servers.keys() {
597 // All servers that are not in desired_servers should be removed from the store.
598 // This can happen if the user removed a server from the context server settings.
599 if !configured_servers.contains_key(server_id) {
600 if disabled_servers.contains_key(&server_id.0) {
601 servers_to_stop.insert(server_id.clone());
602 } else {
603 servers_to_remove.insert(server_id.clone());
604 }
605 }
606 }
607
608 for (id, config) in configured_servers {
609 let state = this.servers.get(&id);
610 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
611 let existing_config = state.as_ref().map(|state| state.configuration());
612 if existing_config.as_deref() != Some(&config) || is_stopped {
613 let config = Arc::new(config);
614 let server = this.create_context_server(id.clone(), config.clone(), cx);
615 servers_to_start.push((server, config));
616 if this.servers.contains_key(&id) {
617 servers_to_stop.insert(id);
618 }
619 }
620 }
621 })?;
622
623 this.update(cx, |this, cx| {
624 for id in servers_to_stop {
625 this.stop_server(&id, cx)?;
626 }
627 for id in servers_to_remove {
628 this.remove_server(&id, cx)?;
629 }
630 for (server, config) in servers_to_start {
631 this.run_server(server, config, cx);
632 }
633 anyhow::Ok(())
634 })?
635 }
636}
637
// Tests for `ContextServerStore`: status tracking, emitted events, and the
// settings/registry maintenance loop. Servers use fake transports, so no processes run.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
        project_settings::ProjectSettings,
    };
    use context_server::test::create_fake_transport;
    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
    use serde_json::json;
    use std::{cell::RefCell, path::PathBuf, rc::Rc};
    use util::path;

    // Starting and stopping servers explicitly is reflected by `status_for_server`.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        // Stopping is synchronous, so no `run_until_parked` is needed before asserting.
        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }

    // The same operations as above emit exactly this sequence of status events.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // The event count is checked when `_server_events` is dropped at the end of the test.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }

    // Two servers sharing an id started back to back: the first is stopped before the second
    // starts, and the store ends up running. Repeated to exercise scheduling orders.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }

    // The maintenance loop starts, restarts, leaves alone, and removes servers as the
    // settings change. Each scoped `_server_events` checks the events of one step on drop.
    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Extension {
                    enabled: true,
                    settings: json!({
                        "somevalue": true
                    }),
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|cx| {
            let mut registry = ContextServerDescriptorRegistry::new();
            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
            registry
        });
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                // Build servers on fake transports instead of spawning processes.
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Stopped),
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        settings::ContextServerSettingsContent::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        settings::ContextServerSettingsContent::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".into(),
                                args: vec!["arg".to_string()],
                                env: None,
                                timeout: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is restarted once the args have changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Stopped),
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        settings::ContextServerSettingsContent::Extension {
                            enabled: true,
                            settings: json!({
                                "somevalue": false
                            }),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        settings::ContextServerSettingsContent::Custom {
                            enabled: true,
                            command: ContextServerCommand {
                                path: "somebinary".into(),
                                args: vec!["anotherArg".to_string()],
                                env: None,
                                timeout: None,
                            },
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }

        // Ensure that nothing happens if the settings do not change
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Extension {
                        enabled: true,
                        settings: json!({
                            "somevalue": false
                        }),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(
                    store.read(cx).status_for_server(&server_1_id),
                    Some(ContextServerStatus::Running)
                );
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }
    }

    // Disabling a server in the settings stops it (it stays in the store); re-enabling it
    // starts it again.
    #[gpui::test]
    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let server_1_id = ContextServerId(SERVER_1_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerSettings::Custom {
                    enabled: true,
                    command: ContextServerCommand {
                        path: "somebinary".into(),
                        args: vec!["arg".to_string()],
                        env: None,
                        timeout: None,
                    },
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    Arc::new(ContextServer::new(
                        id.clone(),
                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
                    ))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is stopped once it is disabled.
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Custom {
                        enabled: false,
                        command: ContextServerCommand {
                            path: "somebinary".into(),
                            args: vec!["arg".to_string()],
                            env: None,
                            timeout: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is started once it is enabled again.
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    settings::ContextServerSettingsContent::Custom {
                        enabled: true,
                        command: ContextServerCommand {
                            path: "somebinary".into(),
                            args: vec!["arg".to_string()],
                            timeout: None,
                            env: None,
                        },
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }
    }

    // Replaces the user's context server settings with exactly `context_servers`.
    fn set_context_server_configuration(
        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |content| {
                    content.project.context_servers.clear();
                    for (id, config) in context_servers {
                        content.project.context_servers.insert(id, config);
                    }
                });
            })
        });
    }

    // Guard returned by `assert_server_events`; on drop it asserts that exactly the expected
    // number of events was received.
    struct ServerEvents {
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }

    impl Drop for ServerEvents {
        fn drop(&mut self) {
            let actual_event_count = *self.received_event_count.borrow();
            assert_eq!(
                actual_event_count, self.expected_event_count,
                "
               Expected to receive {} context server store events, but received {} events",
                self.expected_event_count, actual_event_count
            );
        }
    }

    // An enabled custom server whose command is never actually executed in these tests.
    fn dummy_server_settings() -> ContextServerSettings {
        ContextServerSettings::Custom {
            enabled: true,
            command: ContextServerCommand {
                path: "somebinary".into(),
                args: vec!["arg".to_string()],
                env: None,
                timeout: None,
            },
        }
    }

    // Subscribes to `store` and asserts that each emitted event matches `expected_events`
    // in order. The returned guard checks the total count when dropped. Receiving more
    // events than expected panics when indexing `expected_events`.
    fn assert_server_events(
        store: &Entity<ContextServerStore>,
        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
        cx: &mut TestAppContext,
    ) -> ServerEvents {
        cx.update(|cx| {
            let mut ix = 0;
            let received_event_count = Rc::new(RefCell::new(0));
            let expected_event_count = expected_events.len();
            let subscription = cx.subscribe(store, {
                let received_event_count = received_event_count.clone();
                move |_, event, _| match event {
                    Event::ServerStatusChanged {
                        server_id: actual_server_id,
                        status: actual_status,
                    } => {
                        let (expected_server_id, expected_status) = &expected_events[ix];

                        assert_eq!(
                            actual_server_id, expected_server_id,
                            "Expected different server id at index {}",
                            ix
                        );
                        assert_eq!(
                            actual_status, expected_status,
                            "Expected different status at index {}",
                            ix
                        );
                        ix += 1;
                        *received_event_count.borrow_mut() += 1;
                    }
                }
            });
            ServerEvents {
                expected_event_count,
                received_event_count,
                _subscription: subscription,
            }
        })
    }

    // Installs test settings with the given context servers as global project settings and
    // creates a project over a fake filesystem rooted at `/test` containing `files`.
    async fn setup_context_server_test(
        cx: &mut TestAppContext,
        files: serde_json::Value,
        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
    ) -> (Arc<FakeFs>, Entity<Project>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            let mut settings = ProjectSettings::get_global(cx).clone();
            for (id, config) in context_server_configurations {
                settings.context_servers.insert(id, config);
            }
            ProjectSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        (fs, project)
    }

    // Descriptor that always resolves to a fixed command using `path`, with no extension
    // configuration.
    struct FakeContextServerDescriptor {
        path: PathBuf,
    }

    impl FakeContextServerDescriptor {
        fn new(path: impl Into<PathBuf>) -> Self {
            Self { path: path.into() }
        }
    }

    impl ContextServerDescriptor for FakeContextServerDescriptor {
        fn command(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<ContextServerCommand>> {
            Task::ready(Ok(ContextServerCommand {
                path: self.path.clone(),
                args: vec!["arg1".to_string(), "arg2".to_string()],
                env: None,
                timeout: None,
            }))
        }

        fn configuration(
            &self,
            _worktree_store: Entity<WorktreeStore>,
            _cx: &AsyncApp,
        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
            Task::ready(Ok(None))
        }
    }
}