1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerId};
9use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
10use registry::ContextServerDescriptorRegistry;
11use settings::{Settings as _, SettingsStore};
12use util::ResultExt as _;
13
14use crate::{
15 project_settings::{ContextServerConfiguration, ProjectSettings},
16 worktree_store::WorktreeStore,
17};
18
19pub fn init(cx: &mut App) {
20 extension::init(cx);
21}
22
23actions!(context_server, [Restart]);
24
25#[derive(Debug, Clone, PartialEq, Eq, Hash)]
26pub enum ContextServerStatus {
27 Starting,
28 Running,
29 Stopped,
30 Error(Arc<str>),
31}
32
33impl ContextServerStatus {
34 fn from_state(state: &ContextServerState) -> Self {
35 match state {
36 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
37 ContextServerState::Running { .. } => ContextServerStatus::Running,
38 ContextServerState::Stopped { error, .. } => {
39 if let Some(error) = error {
40 ContextServerStatus::Error(error.clone())
41 } else {
42 ContextServerStatus::Stopped
43 }
44 }
45 }
46 }
47}
48
49enum ContextServerState {
50 Starting {
51 server: Arc<ContextServer>,
52 configuration: Arc<ContextServerConfiguration>,
53 _task: Task<()>,
54 },
55 Running {
56 server: Arc<ContextServer>,
57 configuration: Arc<ContextServerConfiguration>,
58 },
59 Stopped {
60 server: Arc<ContextServer>,
61 configuration: Arc<ContextServerConfiguration>,
62 error: Option<Arc<str>>,
63 },
64}
65
66impl ContextServerState {
67 pub fn server(&self) -> Arc<ContextServer> {
68 match self {
69 ContextServerState::Starting { server, .. } => server.clone(),
70 ContextServerState::Running { server, .. } => server.clone(),
71 ContextServerState::Stopped { server, .. } => server.clone(),
72 }
73 }
74
75 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
76 match self {
77 ContextServerState::Starting { configuration, .. } => configuration.clone(),
78 ContextServerState::Running { configuration, .. } => configuration.clone(),
79 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
80 }
81 }
82}
83
/// Builds a [`ContextServer`] from an id and configuration.
///
/// When a store has a factory, it is used instead of `ContextServer::stdio`, and configurations
/// without a command are still considered valid.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;

/// Tracks the context servers known to a project together with their lifecycle state.
pub struct ContextServerStore {
    /// Current state of every known server, keyed by server id.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Used to pick the settings location (first visible worktree) and handed to extension
    /// descriptors when resolving their commands.
    worktree_store: Entity<WorktreeStore>,
    /// Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The reconciliation task currently in flight, if any.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for how servers are constructed (used by tests).
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a reconciliation was requested while another one was running, so that a
    /// follow-up pass is scheduled when the current one finishes.
    needs_server_update: bool,
    _subscriptions: Vec<Subscription>,
}

/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// A server's status changed; removal of a server is reported as `Stopped`.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
105
106impl ContextServerStore {
107 pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
108 Self::new_internal(
109 true,
110 None,
111 ContextServerDescriptorRegistry::default_global(cx),
112 worktree_store,
113 cx,
114 )
115 }
116
117 #[cfg(any(test, feature = "test-support"))]
118 pub fn test(
119 registry: Entity<ContextServerDescriptorRegistry>,
120 worktree_store: Entity<WorktreeStore>,
121 cx: &mut Context<Self>,
122 ) -> Self {
123 Self::new_internal(false, None, registry, worktree_store, cx)
124 }
125
126 #[cfg(any(test, feature = "test-support"))]
127 pub fn test_maintain_server_loop(
128 context_server_factory: ContextServerFactory,
129 registry: Entity<ContextServerDescriptorRegistry>,
130 worktree_store: Entity<WorktreeStore>,
131 cx: &mut Context<Self>,
132 ) -> Self {
133 Self::new_internal(
134 true,
135 Some(context_server_factory),
136 registry,
137 worktree_store,
138 cx,
139 )
140 }
141
142 fn new_internal(
143 maintain_server_loop: bool,
144 context_server_factory: Option<ContextServerFactory>,
145 registry: Entity<ContextServerDescriptorRegistry>,
146 worktree_store: Entity<WorktreeStore>,
147 cx: &mut Context<Self>,
148 ) -> Self {
149 let subscriptions = if maintain_server_loop {
150 vec![
151 cx.observe(®istry, |this, _registry, cx| {
152 this.available_context_servers_changed(cx);
153 }),
154 cx.observe_global::<SettingsStore>(|this, cx| {
155 this.available_context_servers_changed(cx);
156 }),
157 ]
158 } else {
159 Vec::new()
160 };
161
162 let mut this = Self {
163 _subscriptions: subscriptions,
164 worktree_store,
165 registry,
166 needs_server_update: false,
167 servers: HashMap::default(),
168 update_servers_task: None,
169 context_server_factory,
170 };
171 if maintain_server_loop {
172 this.available_context_servers_changed(cx);
173 }
174 this
175 }
176
177 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
178 self.servers.get(id).map(|state| state.server())
179 }
180
181 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
182 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
183 Some(server.clone())
184 } else {
185 None
186 }
187 }
188
189 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
190 self.servers.get(id).map(ContextServerStatus::from_state)
191 }
192
193 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
194 self.servers.keys().cloned().collect()
195 }
196
197 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
198 self.servers
199 .values()
200 .filter_map(|state| {
201 if let ContextServerState::Running { server, .. } = state {
202 Some(server.clone())
203 } else {
204 None
205 }
206 })
207 .collect()
208 }
209
210 pub fn start_server(
211 &mut self,
212 server: Arc<ContextServer>,
213 cx: &mut Context<Self>,
214 ) -> Result<()> {
215 let location = self
216 .worktree_store
217 .read(cx)
218 .visible_worktrees(cx)
219 .next()
220 .map(|worktree| settings::SettingsLocation {
221 worktree_id: worktree.read(cx).id(),
222 path: Path::new(""),
223 });
224 let settings = ProjectSettings::get(location, cx);
225 let configuration = settings
226 .context_servers
227 .get(&server.id().0)
228 .context("Failed to load context server configuration from settings")?
229 .clone();
230
231 self.run_server(server, Arc::new(configuration), cx);
232 Ok(())
233 }
234
235 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
236 let Some(state) = self.servers.remove(id) else {
237 return Err(anyhow::anyhow!("Context server not found"));
238 };
239
240 let server = state.server();
241 let configuration = state.configuration();
242 let mut result = Ok(());
243 if let ContextServerState::Running { server, .. } = &state {
244 result = server.stop();
245 }
246 drop(state);
247
248 self.update_server_state(
249 id.clone(),
250 ContextServerState::Stopped {
251 configuration,
252 server,
253 error: None,
254 },
255 cx,
256 );
257
258 result
259 }
260
261 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
262 if let Some(state) = self.servers.get(&id) {
263 let configuration = state.configuration();
264
265 self.stop_server(&state.server().id(), cx)?;
266 let new_server = self.create_context_server(id.clone(), configuration.clone())?;
267 self.run_server(new_server, configuration, cx);
268 }
269 Ok(())
270 }
271
272 fn run_server(
273 &mut self,
274 server: Arc<ContextServer>,
275 configuration: Arc<ContextServerConfiguration>,
276 cx: &mut Context<Self>,
277 ) {
278 let id = server.id();
279 if matches!(
280 self.servers.get(&id),
281 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
282 ) {
283 self.stop_server(&id, cx).log_err();
284 }
285
286 let task = cx.spawn({
287 let id = server.id();
288 let server = server.clone();
289 let configuration = configuration.clone();
290 async move |this, cx| {
291 match server.clone().start(&cx).await {
292 Ok(_) => {
293 log::info!("Started {} context server", id);
294 debug_assert!(server.client().is_some());
295
296 this.update(cx, |this, cx| {
297 this.update_server_state(
298 id.clone(),
299 ContextServerState::Running {
300 server,
301 configuration,
302 },
303 cx,
304 )
305 })
306 .log_err()
307 }
308 Err(err) => {
309 log::error!("{} context server failed to start: {}", id, err);
310 this.update(cx, |this, cx| {
311 this.update_server_state(
312 id.clone(),
313 ContextServerState::Stopped {
314 configuration,
315 server,
316 error: Some(err.to_string().into()),
317 },
318 cx,
319 )
320 })
321 .log_err()
322 }
323 };
324 }
325 });
326
327 self.update_server_state(
328 id.clone(),
329 ContextServerState::Starting {
330 configuration,
331 _task: task,
332 server,
333 },
334 cx,
335 );
336 }
337
338 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
339 let Some(state) = self.servers.remove(id) else {
340 return Err(anyhow::anyhow!("Context server not found"));
341 };
342 drop(state);
343 cx.emit(Event::ServerStatusChanged {
344 server_id: id.clone(),
345 status: ContextServerStatus::Stopped,
346 });
347 Ok(())
348 }
349
350 fn is_configuration_valid(&self, configuration: &ContextServerConfiguration) -> bool {
351 // Command must be some when we are running in stdio mode.
352 self.context_server_factory.as_ref().is_some() || configuration.command.is_some()
353 }
354
355 fn create_context_server(
356 &self,
357 id: ContextServerId,
358 configuration: Arc<ContextServerConfiguration>,
359 ) -> Result<Arc<ContextServer>> {
360 if let Some(factory) = self.context_server_factory.as_ref() {
361 Ok(factory(id, configuration))
362 } else {
363 let command = configuration
364 .command
365 .clone()
366 .context("Missing command to run context server")?;
367 Ok(Arc::new(ContextServer::stdio(id, command)))
368 }
369 }
370
371 fn update_server_state(
372 &mut self,
373 id: ContextServerId,
374 state: ContextServerState,
375 cx: &mut Context<Self>,
376 ) {
377 let status = ContextServerStatus::from_state(&state);
378 self.servers.insert(id.clone(), state);
379 cx.emit(Event::ServerStatusChanged {
380 server_id: id,
381 status,
382 });
383 }
384
385 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
386 if self.update_servers_task.is_some() {
387 self.needs_server_update = true;
388 } else {
389 self.needs_server_update = false;
390 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
391 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
392 log::error!("Error maintaining context servers: {}", err);
393 }
394
395 this.update(cx, |this, cx| {
396 this.update_servers_task.take();
397 if this.needs_server_update {
398 this.available_context_servers_changed(cx);
399 }
400 })?;
401
402 Ok(())
403 }));
404 }
405 }
406
407 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
408 let mut desired_servers = HashMap::default();
409
410 let (registry, worktree_store) = this.update(cx, |this, cx| {
411 let location = this
412 .worktree_store
413 .read(cx)
414 .visible_worktrees(cx)
415 .next()
416 .map(|worktree| settings::SettingsLocation {
417 worktree_id: worktree.read(cx).id(),
418 path: Path::new(""),
419 });
420 let settings = ProjectSettings::get(location, cx);
421 desired_servers = settings.context_servers.clone();
422
423 (this.registry.clone(), this.worktree_store.clone())
424 })?;
425
426 for (id, descriptor) in
427 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
428 {
429 let config = desired_servers.entry(id.clone()).or_default();
430 if config.command.is_none() {
431 if let Some(extension_command) = descriptor
432 .command(worktree_store.clone(), &cx)
433 .await
434 .log_err()
435 {
436 config.command = Some(extension_command);
437 }
438 }
439 }
440
441 this.update(cx, |this, _| {
442 // Filter out configurations without commands, the user uninstalled an extension.
443 desired_servers.retain(|_, configuration| this.is_configuration_valid(configuration));
444 })?;
445
446 let mut servers_to_start = Vec::new();
447 let mut servers_to_remove = HashSet::default();
448 let mut servers_to_stop = HashSet::default();
449
450 this.update(cx, |this, _cx| {
451 for server_id in this.servers.keys() {
452 // All servers that are not in desired_servers should be removed from the store.
453 // E.g. this can happen if the user removed a server from the configuration,
454 // or the user uninstalled an extension.
455 if !desired_servers.contains_key(&server_id.0) {
456 servers_to_remove.insert(server_id.clone());
457 }
458 }
459
460 for (id, config) in desired_servers {
461 let id = ContextServerId(id.clone());
462
463 let existing_config = this.servers.get(&id).map(|state| state.configuration());
464 if existing_config.as_deref() != Some(&config) {
465 let config = Arc::new(config);
466 if let Some(server) = this
467 .create_context_server(id.clone(), config.clone())
468 .log_err()
469 {
470 servers_to_start.push((server, config));
471 if this.servers.contains_key(&id) {
472 servers_to_stop.insert(id);
473 }
474 }
475 }
476 }
477 })?;
478
479 for id in servers_to_stop {
480 this.update(cx, |this, cx| this.stop_server(&id, cx).ok())?;
481 }
482
483 for id in servers_to_remove {
484 this.update(cx, |this, cx| this.remove_server(&id, cx).ok())?;
485 }
486
487 for (server, config) in servers_to_start {
488 this.update(cx, |this, cx| this.run_server(server, config, cx))
489 .log_err();
490 }
491
492 Ok(())
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FakeFs, Project, project_settings::ProjectSettings};
    use context_server::{
        transport::Transport,
        types::{
            self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
            ServerCapabilities,
        },
    };
    use futures::{Stream, StreamExt as _, lock::Mutex};
    use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
    use serde_json::json;
    use std::{cell::RefCell, pin::Pin, rc::Rc};
    use util::path;

    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), ContextServerConfiguration::default()),
                (SERVER_2_ID.into(), ContextServerConfiguration::default()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let transport_1 = initialize_only_transport(cx.executor(), SERVER_1_ID);
        let transport_2 = initialize_only_transport(cx.executor(), SERVER_2_ID);

        let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
        let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));

        store
            .update(cx, |store, cx| store.start_server(server_1, cx))
            .unwrap();

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store
            .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
            .unwrap();

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }

    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), ContextServerConfiguration::default()),
                (SERVER_2_ID.into(), ContextServerConfiguration::default()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let transport_1 = initialize_only_transport(cx.executor(), SERVER_1_ID);
        let transport_2 = initialize_only_transport(cx.executor(), SERVER_2_ID);

        let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
        let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));

        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store
            .update(cx, |store, cx| store.start_server(server_1, cx))
            .unwrap();

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
            .unwrap();

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }

    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), ContextServerConfiguration::default())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let transport_1 = initialize_only_transport(cx.executor(), SERVER_1_ID);
        let transport_2 = initialize_only_transport(cx.executor(), SERVER_1_ID);

        let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
        let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        store
            .update(cx, |store, cx| {
                store.start_server(server_with_same_id_1.clone(), cx)
            })
            .unwrap();
        store
            .update(cx, |store, cx| {
                store.start_server(server_with_same_id_2.clone(), cx)
            })
            .unwrap();
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Starting)
            );
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }

    #[gpui::test]
    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(
                SERVER_1_ID.into(),
                ContextServerConfiguration {
                    command: None,
                    settings: Some(json!({
                        "somevalue": true
                    })),
                },
            )],
        )
        .await;

        let executor = cx.executor();
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test_maintain_server_loop(
                Box::new(move |id, _| {
                    let transport = FakeTransport::new(executor.clone(), {
                        let id = id.0.clone();
                        move |_, request_type, _| match request_type {
                            Some(RequestType::Initialize) => {
                                Some(create_initialize_response(id.clone().to_string()))
                            }
                            _ => None,
                        }
                    });
                    Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
                }),
                registry.clone(),
                project.read(cx).worktree_store(),
                cx,
            )
        });

        // Ensure that mcp-1 starts up
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            cx.run_until_parked();
        }

        // Ensure that mcp-1 is restarted when the configuration was changed
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_1_id.clone(), ContextServerStatus::Stopped),
                    (server_1_id.clone(), ContextServerStatus::Starting),
                    (server_1_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerConfiguration {
                        command: None,
                        settings: Some(json!({
                            "somevalue": false
                        })),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-1 is not restarted when the configuration was not changed
        {
            let _server_events = assert_server_events(&store, vec![], cx);
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerConfiguration {
                        command: None,
                        settings: Some(json!({
                            "somevalue": false
                        })),
                    },
                )],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is started once it is added to the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![
                    (server_2_id.clone(), ContextServerStatus::Starting),
                    (server_2_id.clone(), ContextServerStatus::Running),
                ],
                cx,
            );
            set_context_server_configuration(
                vec![
                    (
                        server_1_id.0.clone(),
                        ContextServerConfiguration {
                            command: None,
                            settings: Some(json!({
                                "somevalue": false
                            })),
                        },
                    ),
                    (
                        server_2_id.0.clone(),
                        ContextServerConfiguration {
                            command: None,
                            settings: Some(json!({
                                "somevalue": true
                            })),
                        },
                    ),
                ],
                cx,
            );

            cx.run_until_parked();
        }

        // Ensure that mcp-2 is removed once it is removed from the settings
        {
            let _server_events = assert_server_events(
                &store,
                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
                cx,
            );
            set_context_server_configuration(
                vec![(
                    server_1_id.0.clone(),
                    ContextServerConfiguration {
                        command: None,
                        settings: Some(json!({
                            "somevalue": false
                        })),
                    },
                )],
                cx,
            );

            cx.run_until_parked();

            cx.update(|cx| {
                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
            });
        }
    }

    /// Replaces the user settings so that exactly `context_servers` are configured.
    fn set_context_server_configuration(
        context_servers: Vec<(Arc<str>, ContextServerConfiguration)>,
        cx: &mut TestAppContext,
    ) {
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                let mut settings = ProjectSettings::default();
                for (id, config) in context_servers {
                    settings.context_servers.insert(id, config);
                }
                store
                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
                    .unwrap();
            })
        });
    }

    /// Guard returned by `assert_server_events`; checks on drop that every expected event
    /// was received.
    struct ServerEvents {
        received_event_count: Rc<RefCell<usize>>,
        expected_event_count: usize,
        _subscription: Subscription,
    }

    impl Drop for ServerEvents {
        fn drop(&mut self) {
            // Skip the check while unwinding from an earlier failure: panicking again inside a
            // destructor during unwinding aborts the process and hides the original message.
            if std::thread::panicking() {
                return;
            }
            let actual_event_count = *self.received_event_count.borrow();
            assert_eq!(
                actual_event_count, self.expected_event_count,
                "Expected to receive {} context server store events, but received {} events",
                self.expected_event_count, actual_event_count
            );
        }
    }

    /// Subscribes to `store` and asserts that its status events match `expected_events`
    /// in order. The count is verified when the returned guard is dropped.
    fn assert_server_events(
        store: &Entity<ContextServerStore>,
        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
        cx: &mut TestAppContext,
    ) -> ServerEvents {
        cx.update(|cx| {
            let mut ix = 0;
            let received_event_count = Rc::new(RefCell::new(0));
            let expected_event_count = expected_events.len();
            let subscription = cx.subscribe(store, {
                let received_event_count = received_event_count.clone();
                move |_, event, _| match event {
                    Event::ServerStatusChanged {
                        server_id: actual_server_id,
                        status: actual_status,
                    } => {
                        let Some((expected_server_id, expected_status)) = expected_events.get(ix)
                        else {
                            panic!(
                                "Received unexpected event at index {}: {:?} -> {:?}",
                                ix, actual_server_id, actual_status
                            );
                        };

                        assert_eq!(
                            actual_server_id, expected_server_id,
                            "Expected different server id at index {}",
                            ix
                        );
                        assert_eq!(
                            actual_status, expected_status,
                            "Expected different status at index {}",
                            ix
                        );
                        ix += 1;
                        *received_event_count.borrow_mut() += 1;
                    }
                }
            });
            ServerEvents {
                expected_event_count,
                received_event_count,
                _subscription: subscription,
            }
        })
    }

    /// Installs test settings with the given server configurations and creates a project
    /// over a fake file system populated with `files` under `/test`.
    async fn setup_context_server_test(
        cx: &mut TestAppContext,
        files: serde_json::Value,
        context_server_configurations: Vec<(Arc<str>, ContextServerConfiguration)>,
    ) -> (Arc<FakeFs>, Entity<Project>) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            let mut settings = ProjectSettings::get_global(cx).clone();
            for (id, config) in context_server_configurations {
                settings.context_servers.insert(id, config);
            }
            ProjectSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        (fs, project)
    }

    /// Builds a JSON `initialize` response for a server named `server_name`.
    fn create_initialize_response(server_name: String) -> serde_json::Value {
        serde_json::to_value(&InitializeResponse {
            protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
            server_info: Implementation {
                name: server_name,
                version: "1.0.0".to_string(),
            },
            capabilities: ServerCapabilities::default(),
            meta: None,
        })
        .unwrap()
    }

    /// Creates a fake transport that answers `initialize` as `server_name` and ignores every
    /// other request.
    fn initialize_only_transport(
        executor: BackgroundExecutor,
        server_name: &'static str,
    ) -> Arc<FakeTransport> {
        Arc::new(FakeTransport::new(
            executor,
            move |_, request_type, _| match request_type {
                Some(RequestType::Initialize) => {
                    Some(create_initialize_response(server_name.to_string()))
                }
                _ => None,
            },
        ))
    }

    /// In-memory transport: requests are answered via `on_request`, and responses are
    /// delivered through a channel with a simulated random delay.
    struct FakeTransport {
        on_request: Arc<
            dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
                + Send
                + Sync,
        >,
        tx: futures::channel::mpsc::UnboundedSender<String>,
        rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
        executor: BackgroundExecutor,
    }

    impl FakeTransport {
        fn new(
            executor: BackgroundExecutor,
            on_request: impl Fn(
                u64,
                Option<RequestType>,
                serde_json::Value,
            ) -> Option<serde_json::Value>
            + 'static
            + Send
            + Sync,
        ) -> Self {
            let (tx, rx) = futures::channel::mpsc::unbounded();
            Self {
                on_request: Arc::new(on_request),
                tx,
                rx: Arc::new(Mutex::new(rx)),
                executor,
            }
        }
    }

    #[async_trait::async_trait]
    impl Transport for FakeTransport {
        async fn send(&self, message: String) -> Result<()> {
            if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
                let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);

                if let Some(method) = msg.get("method") {
                    let request_type = method
                        .as_str()
                        .and_then(|method| types::RequestType::try_from(method).ok());
                    if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
                        let response = serde_json::json!({
                            "jsonrpc": "2.0",
                            "id": id,
                            "result": payload
                        });

                        self.tx
                            .unbounded_send(response.to_string())
                            .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
                    }
                }
            }
            Ok(())
        }

        fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
            let rx = self.rx.clone();
            let executor = self.executor.clone();
            Box::pin(futures::stream::unfold(rx, move |rx| {
                let executor = executor.clone();
                async move {
                    let mut rx_guard = rx.lock().await;
                    executor.simulate_random_delay().await;
                    if let Some(message) = rx_guard.next().await {
                        drop(rx_guard);
                        Some((message, rx))
                    } else {
                        None
                    }
                }
            }))
        }

        fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
            Box::pin(futures::stream::empty())
        }
    }
}