1pub mod extension;
2pub mod registry;
3
4use std::{path::Path, sync::Arc};
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerId};
9use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
10use registry::ContextServerDescriptorRegistry;
11use settings::{Settings as _, SettingsStore};
12use util::ResultExt as _;
13
14use crate::{
15 project_settings::{ContextServerConfiguration, ProjectSettings},
16 worktree_store::WorktreeStore,
17};
18
/// Initializes context-server support for the application.
///
/// Currently this only delegates to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
22
// Declares the `Restart` action under the `context_server` namespace via gpui's `actions!` macro.
actions!(context_server, [Restart]);
24
/// Publicly visible lifecycle status of a context server.
///
/// This is the value reported by [`ContextServerStore::status_for_server`] and carried by
/// [`Event::ServerStatusChanged`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's start task has been spawned but has not completed yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server is stopped and no error was recorded for it.
    Stopped,
    /// The server is stopped and an error message was recorded (e.g. it failed to start).
    Error(Arc<str>),
}
32
33impl ContextServerStatus {
34 fn from_state(state: &ContextServerState) -> Self {
35 match state {
36 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
37 ContextServerState::Running { .. } => ContextServerStatus::Running,
38 ContextServerState::Stopped { error, .. } => {
39 if let Some(error) = error {
40 ContextServerStatus::Error(error.clone())
41 } else {
42 ContextServerStatus::Stopped
43 }
44 }
45 }
46 }
47}
48
/// Internal bookkeeping the store keeps for each known context server.
///
/// Every variant carries the server handle and the configuration it was created with, so
/// both can be retrieved regardless of the current lifecycle phase.
enum ContextServerState {
    /// The server's asynchronous start is in flight.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Held so the start task lives as long as this state does.
        // NOTE(review): presumably dropping it cancels the start — confirm against gpui `Task`.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is not running; `error` is set when it stopped because starting failed.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Option<Arc<str>>,
    },
}
65
66impl ContextServerState {
67 pub fn server(&self) -> Arc<ContextServer> {
68 match self {
69 ContextServerState::Starting { server, .. } => server.clone(),
70 ContextServerState::Running { server, .. } => server.clone(),
71 ContextServerState::Stopped { server, .. } => server.clone(),
72 }
73 }
74
75 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
76 match self {
77 ContextServerState::Starting { configuration, .. } => configuration.clone(),
78 ContextServerState::Running { configuration, .. } => configuration.clone(),
79 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
80 }
81 }
82}
83
/// Callback that builds a [`ContextServer`] from its id and configuration.
///
/// When a store holds one, `create_context_server` calls it instead of constructing a
/// stdio server from the configured command.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
86
/// Tracks context servers by id, drives their start/stop lifecycle and emits
/// [`Event::ServerStatusChanged`] whenever a server's state is replaced or removed.
pub struct ContextServerStore {
    // Current state of every known server, keyed by server id.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to locate the first visible worktree when resolving project settings.
    worktree_store: Entity<WorktreeStore>,
    // Source of context server descriptors (consulted for commands missing from settings).
    registry: Entity<ContextServerDescriptorRegistry>,
    // In-flight reconciliation task, if any; see `available_context_servers_changed`.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for constructing servers; within this file only the test constructor sets it.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a change arrives while a reconciliation task is already running, so another
    // pass is scheduled once it finishes.
    needs_server_update: bool,
    // Keeps the registry/settings observers alive for the lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
96
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// The state of the server identified by `server_id` was replaced or removed;
    /// `status` is the resulting public status.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

// Allows the store to emit `Event`s to its subscribers.
impl EventEmitter<Event> for ContextServerStore {}
105
106impl ContextServerStore {
107 pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
108 Self::new_internal(
109 true,
110 None,
111 ContextServerDescriptorRegistry::default_global(cx),
112 worktree_store,
113 cx,
114 )
115 }
116
/// Test constructor: builds a store that does NOT observe the registry or settings and
/// has no server factory, so servers are only started when the test asks for it.
#[cfg(any(test, feature = "test-support"))]
pub fn test(
    registry: Entity<ContextServerDescriptorRegistry>,
    worktree_store: Entity<WorktreeStore>,
    cx: &mut Context<Self>,
) -> Self {
    Self::new_internal(false, None, registry, worktree_store, cx)
}
125
/// Test constructor: builds a store that keeps its servers in sync with the registry and
/// settings, creating every server through `context_server_factory`.
#[cfg(any(test, feature = "test-support"))]
pub fn test_maintain_server_loop(
    context_server_factory: ContextServerFactory,
    registry: Entity<ContextServerDescriptorRegistry>,
    worktree_store: Entity<WorktreeStore>,
    cx: &mut Context<Self>,
) -> Self {
    let factory_override = Some(context_server_factory);
    Self::new_internal(true, factory_override, registry, worktree_store, cx)
}
141
142 fn new_internal(
143 maintain_server_loop: bool,
144 context_server_factory: Option<ContextServerFactory>,
145 registry: Entity<ContextServerDescriptorRegistry>,
146 worktree_store: Entity<WorktreeStore>,
147 cx: &mut Context<Self>,
148 ) -> Self {
149 let subscriptions = if maintain_server_loop {
150 vec![
151 cx.observe(®istry, |this, _registry, cx| {
152 this.available_context_servers_changed(cx);
153 }),
154 cx.observe_global::<SettingsStore>(|this, cx| {
155 this.available_context_servers_changed(cx);
156 }),
157 ]
158 } else {
159 Vec::new()
160 };
161
162 let mut this = Self {
163 _subscriptions: subscriptions,
164 worktree_store,
165 registry,
166 needs_server_update: false,
167 servers: HashMap::default(),
168 update_servers_task: None,
169 context_server_factory,
170 };
171 if maintain_server_loop {
172 this.available_context_servers_changed(cx);
173 }
174 this
175 }
176
177 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
178 self.servers.get(id).map(|state| state.server())
179 }
180
181 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
182 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
183 Some(server.clone())
184 } else {
185 None
186 }
187 }
188
189 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
190 self.servers.get(id).map(ContextServerStatus::from_state)
191 }
192
193 pub fn all_server_ids(&self) -> Vec<ContextServerId> {
194 self.servers.keys().cloned().collect()
195 }
196
197 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
198 self.servers
199 .values()
200 .filter_map(|state| {
201 if let ContextServerState::Running { server, .. } = state {
202 Some(server.clone())
203 } else {
204 None
205 }
206 })
207 .collect()
208 }
209
210 pub fn start_server(
211 &mut self,
212 server: Arc<ContextServer>,
213 cx: &mut Context<Self>,
214 ) -> Result<()> {
215 let location = self
216 .worktree_store
217 .read(cx)
218 .visible_worktrees(cx)
219 .next()
220 .map(|worktree| settings::SettingsLocation {
221 worktree_id: worktree.read(cx).id(),
222 path: Path::new(""),
223 });
224 let settings = ProjectSettings::get(location, cx);
225 let configuration = settings
226 .context_servers
227 .get(&server.id().0)
228 .context("Failed to load context server configuration from settings")?
229 .clone();
230
231 self.run_server(server, Arc::new(configuration), cx);
232 Ok(())
233 }
234
235 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
236 let state = self
237 .servers
238 .remove(id)
239 .context("Context server not found")?;
240
241 let server = state.server();
242 let configuration = state.configuration();
243 let mut result = Ok(());
244 if let ContextServerState::Running { server, .. } = &state {
245 result = server.stop();
246 }
247 drop(state);
248
249 self.update_server_state(
250 id.clone(),
251 ContextServerState::Stopped {
252 configuration,
253 server,
254 error: None,
255 },
256 cx,
257 );
258
259 result
260 }
261
262 pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
263 if let Some(state) = self.servers.get(&id) {
264 let configuration = state.configuration();
265
266 self.stop_server(&state.server().id(), cx)?;
267 let new_server = self.create_context_server(id.clone(), configuration.clone())?;
268 self.run_server(new_server, configuration, cx);
269 }
270 Ok(())
271 }
272
273 fn run_server(
274 &mut self,
275 server: Arc<ContextServer>,
276 configuration: Arc<ContextServerConfiguration>,
277 cx: &mut Context<Self>,
278 ) {
279 let id = server.id();
280 if matches!(
281 self.servers.get(&id),
282 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
283 ) {
284 self.stop_server(&id, cx).log_err();
285 }
286
287 let task = cx.spawn({
288 let id = server.id();
289 let server = server.clone();
290 let configuration = configuration.clone();
291 async move |this, cx| {
292 match server.clone().start(&cx).await {
293 Ok(_) => {
294 log::info!("Started {} context server", id);
295 debug_assert!(server.client().is_some());
296
297 this.update(cx, |this, cx| {
298 this.update_server_state(
299 id.clone(),
300 ContextServerState::Running {
301 server,
302 configuration,
303 },
304 cx,
305 )
306 })
307 .log_err()
308 }
309 Err(err) => {
310 log::error!("{} context server failed to start: {}", id, err);
311 this.update(cx, |this, cx| {
312 this.update_server_state(
313 id.clone(),
314 ContextServerState::Stopped {
315 configuration,
316 server,
317 error: Some(err.to_string().into()),
318 },
319 cx,
320 )
321 })
322 .log_err()
323 }
324 };
325 }
326 });
327
328 self.update_server_state(
329 id.clone(),
330 ContextServerState::Starting {
331 configuration,
332 _task: task,
333 server,
334 },
335 cx,
336 );
337 }
338
339 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
340 let state = self
341 .servers
342 .remove(id)
343 .context("Context server not found")?;
344 drop(state);
345 cx.emit(Event::ServerStatusChanged {
346 server_id: id.clone(),
347 status: ContextServerStatus::Stopped,
348 });
349 Ok(())
350 }
351
352 fn is_configuration_valid(&self, configuration: &ContextServerConfiguration) -> bool {
353 // Command must be some when we are running in stdio mode.
354 self.context_server_factory.as_ref().is_some() || configuration.command.is_some()
355 }
356
357 fn create_context_server(
358 &self,
359 id: ContextServerId,
360 configuration: Arc<ContextServerConfiguration>,
361 ) -> Result<Arc<ContextServer>> {
362 if let Some(factory) = self.context_server_factory.as_ref() {
363 Ok(factory(id, configuration))
364 } else {
365 let command = configuration
366 .command
367 .clone()
368 .context("Missing command to run context server")?;
369 Ok(Arc::new(ContextServer::stdio(id, command)))
370 }
371 }
372
373 fn update_server_state(
374 &mut self,
375 id: ContextServerId,
376 state: ContextServerState,
377 cx: &mut Context<Self>,
378 ) {
379 let status = ContextServerStatus::from_state(&state);
380 self.servers.insert(id.clone(), state);
381 cx.emit(Event::ServerStatusChanged {
382 server_id: id,
383 status,
384 });
385 }
386
387 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
388 if self.update_servers_task.is_some() {
389 self.needs_server_update = true;
390 } else {
391 self.needs_server_update = false;
392 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
393 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
394 log::error!("Error maintaining context servers: {}", err);
395 }
396
397 this.update(cx, |this, cx| {
398 this.update_servers_task.take();
399 if this.needs_server_update {
400 this.available_context_servers_changed(cx);
401 }
402 })?;
403
404 Ok(())
405 }));
406 }
407 }
408
409 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
410 let mut desired_servers = HashMap::default();
411
412 let (registry, worktree_store) = this.update(cx, |this, cx| {
413 let location = this
414 .worktree_store
415 .read(cx)
416 .visible_worktrees(cx)
417 .next()
418 .map(|worktree| settings::SettingsLocation {
419 worktree_id: worktree.read(cx).id(),
420 path: Path::new(""),
421 });
422 let settings = ProjectSettings::get(location, cx);
423 desired_servers = settings.context_servers.clone();
424
425 (this.registry.clone(), this.worktree_store.clone())
426 })?;
427
428 for (id, descriptor) in
429 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
430 {
431 let config = desired_servers.entry(id.clone()).or_default();
432 if config.command.is_none() {
433 if let Some(extension_command) = descriptor
434 .command(worktree_store.clone(), &cx)
435 .await
436 .log_err()
437 {
438 config.command = Some(extension_command);
439 }
440 }
441 }
442
443 this.update(cx, |this, _| {
444 // Filter out configurations without commands, the user uninstalled an extension.
445 desired_servers.retain(|_, configuration| this.is_configuration_valid(configuration));
446 })?;
447
448 let mut servers_to_start = Vec::new();
449 let mut servers_to_remove = HashSet::default();
450 let mut servers_to_stop = HashSet::default();
451
452 this.update(cx, |this, _cx| {
453 for server_id in this.servers.keys() {
454 // All servers that are not in desired_servers should be removed from the store.
455 // E.g. this can happen if the user removed a server from the configuration,
456 // or the user uninstalled an extension.
457 if !desired_servers.contains_key(&server_id.0) {
458 servers_to_remove.insert(server_id.clone());
459 }
460 }
461
462 for (id, config) in desired_servers {
463 let id = ContextServerId(id.clone());
464
465 let existing_config = this.servers.get(&id).map(|state| state.configuration());
466 if existing_config.as_deref() != Some(&config) {
467 let config = Arc::new(config);
468 if let Some(server) = this
469 .create_context_server(id.clone(), config.clone())
470 .log_err()
471 {
472 servers_to_start.push((server, config));
473 if this.servers.contains_key(&id) {
474 servers_to_stop.insert(id);
475 }
476 }
477 }
478 }
479 })?;
480
481 for id in servers_to_stop {
482 this.update(cx, |this, cx| this.stop_server(&id, cx).ok())?;
483 }
484
485 for id in servers_to_remove {
486 this.update(cx, |this, cx| this.remove_server(&id, cx).ok())?;
487 }
488
489 for (server, config) in servers_to_start {
490 this.update(cx, |this, cx| this.run_server(server, config, cx))
491 .log_err();
492 }
493
494 Ok(())
495 }
496}
497
498#[cfg(test)]
499mod tests {
500 use super::*;
501 use crate::{FakeFs, Project, project_settings::ProjectSettings};
502 use context_server::{
503 transport::Transport,
504 types::{
505 self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
506 ServerCapabilities,
507 },
508 };
509 use futures::{Stream, StreamExt as _, lock::Mutex};
510 use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
511 use serde_json::json;
512 use std::{cell::RefCell, pin::Pin, rc::Rc};
513 use util::path;
514
515 #[gpui::test]
516 async fn test_context_server_status(cx: &mut TestAppContext) {
517 const SERVER_1_ID: &'static str = "mcp-1";
518 const SERVER_2_ID: &'static str = "mcp-2";
519
520 let (_fs, project) = setup_context_server_test(
521 cx,
522 json!({"code.rs": ""}),
523 vec![
524 (SERVER_1_ID.into(), ContextServerConfiguration::default()),
525 (SERVER_2_ID.into(), ContextServerConfiguration::default()),
526 ],
527 )
528 .await;
529
530 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
531 let store = cx.new(|cx| {
532 ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
533 });
534
535 let server_1_id = ContextServerId("mcp-1".into());
536 let server_2_id = ContextServerId("mcp-2".into());
537
538 let transport_1 =
539 Arc::new(FakeTransport::new(
540 cx.executor(),
541 |_, request_type, _| match request_type {
542 Some(RequestType::Initialize) => {
543 Some(create_initialize_response("mcp-1".to_string()))
544 }
545 _ => None,
546 },
547 ));
548
549 let transport_2 =
550 Arc::new(FakeTransport::new(
551 cx.executor(),
552 |_, request_type, _| match request_type {
553 Some(RequestType::Initialize) => {
554 Some(create_initialize_response("mcp-2".to_string()))
555 }
556 _ => None,
557 },
558 ));
559
560 let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
561 let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
562
563 store
564 .update(cx, |store, cx| store.start_server(server_1, cx))
565 .unwrap();
566
567 cx.run_until_parked();
568
569 cx.update(|cx| {
570 assert_eq!(
571 store.read(cx).status_for_server(&server_1_id),
572 Some(ContextServerStatus::Running)
573 );
574 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
575 });
576
577 store
578 .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
579 .unwrap();
580
581 cx.run_until_parked();
582
583 cx.update(|cx| {
584 assert_eq!(
585 store.read(cx).status_for_server(&server_1_id),
586 Some(ContextServerStatus::Running)
587 );
588 assert_eq!(
589 store.read(cx).status_for_server(&server_2_id),
590 Some(ContextServerStatus::Running)
591 );
592 });
593
594 store
595 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
596 .unwrap();
597
598 cx.update(|cx| {
599 assert_eq!(
600 store.read(cx).status_for_server(&server_1_id),
601 Some(ContextServerStatus::Running)
602 );
603 assert_eq!(
604 store.read(cx).status_for_server(&server_2_id),
605 Some(ContextServerStatus::Stopped)
606 );
607 });
608 }
609
610 #[gpui::test]
611 async fn test_context_server_status_events(cx: &mut TestAppContext) {
612 const SERVER_1_ID: &'static str = "mcp-1";
613 const SERVER_2_ID: &'static str = "mcp-2";
614
615 let (_fs, project) = setup_context_server_test(
616 cx,
617 json!({"code.rs": ""}),
618 vec![
619 (SERVER_1_ID.into(), ContextServerConfiguration::default()),
620 (SERVER_2_ID.into(), ContextServerConfiguration::default()),
621 ],
622 )
623 .await;
624
625 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
626 let store = cx.new(|cx| {
627 ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
628 });
629
630 let server_1_id = ContextServerId("mcp-1".into());
631 let server_2_id = ContextServerId("mcp-2".into());
632
633 let transport_1 =
634 Arc::new(FakeTransport::new(
635 cx.executor(),
636 |_, request_type, _| match request_type {
637 Some(RequestType::Initialize) => {
638 Some(create_initialize_response("mcp-1".to_string()))
639 }
640 _ => None,
641 },
642 ));
643
644 let transport_2 =
645 Arc::new(FakeTransport::new(
646 cx.executor(),
647 |_, request_type, _| match request_type {
648 Some(RequestType::Initialize) => {
649 Some(create_initialize_response("mcp-2".to_string()))
650 }
651 _ => None,
652 },
653 ));
654
655 let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
656 let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
657
658 let _server_events = assert_server_events(
659 &store,
660 vec![
661 (server_1_id.clone(), ContextServerStatus::Starting),
662 (server_1_id.clone(), ContextServerStatus::Running),
663 (server_2_id.clone(), ContextServerStatus::Starting),
664 (server_2_id.clone(), ContextServerStatus::Running),
665 (server_2_id.clone(), ContextServerStatus::Stopped),
666 ],
667 cx,
668 );
669
670 store
671 .update(cx, |store, cx| store.start_server(server_1, cx))
672 .unwrap();
673
674 cx.run_until_parked();
675
676 store
677 .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
678 .unwrap();
679
680 cx.run_until_parked();
681
682 store
683 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
684 .unwrap();
685 }
686
687 #[gpui::test(iterations = 25)]
688 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
689 const SERVER_1_ID: &'static str = "mcp-1";
690
691 let (_fs, project) = setup_context_server_test(
692 cx,
693 json!({"code.rs": ""}),
694 vec![(SERVER_1_ID.into(), ContextServerConfiguration::default())],
695 )
696 .await;
697
698 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
699 let store = cx.new(|cx| {
700 ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
701 });
702
703 let server_id = ContextServerId(SERVER_1_ID.into());
704
705 let transport_1 =
706 Arc::new(FakeTransport::new(
707 cx.executor(),
708 |_, request_type, _| match request_type {
709 Some(RequestType::Initialize) => {
710 Some(create_initialize_response(SERVER_1_ID.to_string()))
711 }
712 _ => None,
713 },
714 ));
715
716 let transport_2 =
717 Arc::new(FakeTransport::new(
718 cx.executor(),
719 |_, request_type, _| match request_type {
720 Some(RequestType::Initialize) => {
721 Some(create_initialize_response(SERVER_1_ID.to_string()))
722 }
723 _ => None,
724 },
725 ));
726
727 let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
728 let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
729
730 // If we start another server with the same id, we should report that we stopped the previous one
731 let _server_events = assert_server_events(
732 &store,
733 vec![
734 (server_id.clone(), ContextServerStatus::Starting),
735 (server_id.clone(), ContextServerStatus::Stopped),
736 (server_id.clone(), ContextServerStatus::Starting),
737 (server_id.clone(), ContextServerStatus::Running),
738 ],
739 cx,
740 );
741
742 store
743 .update(cx, |store, cx| {
744 store.start_server(server_with_same_id_1.clone(), cx)
745 })
746 .unwrap();
747 store
748 .update(cx, |store, cx| {
749 store.start_server(server_with_same_id_2.clone(), cx)
750 })
751 .unwrap();
752 cx.update(|cx| {
753 assert_eq!(
754 store.read(cx).status_for_server(&server_id),
755 Some(ContextServerStatus::Starting)
756 );
757 });
758
759 cx.run_until_parked();
760
761 cx.update(|cx| {
762 assert_eq!(
763 store.read(cx).status_for_server(&server_id),
764 Some(ContextServerStatus::Running)
765 );
766 });
767 }
768
769 #[gpui::test]
770 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
771 const SERVER_1_ID: &'static str = "mcp-1";
772 const SERVER_2_ID: &'static str = "mcp-2";
773
774 let server_1_id = ContextServerId(SERVER_1_ID.into());
775 let server_2_id = ContextServerId(SERVER_2_ID.into());
776
777 let (_fs, project) = setup_context_server_test(
778 cx,
779 json!({"code.rs": ""}),
780 vec![(
781 SERVER_1_ID.into(),
782 ContextServerConfiguration {
783 command: None,
784 settings: Some(json!({
785 "somevalue": true
786 })),
787 },
788 )],
789 )
790 .await;
791
792 let executor = cx.executor();
793 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
794 let store = cx.new(|cx| {
795 ContextServerStore::test_maintain_server_loop(
796 Box::new(move |id, _| {
797 let transport = FakeTransport::new(executor.clone(), {
798 let id = id.0.clone();
799 move |_, request_type, _| match request_type {
800 Some(RequestType::Initialize) => {
801 Some(create_initialize_response(id.clone().to_string()))
802 }
803 _ => None,
804 }
805 });
806 Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
807 }),
808 registry.clone(),
809 project.read(cx).worktree_store(),
810 cx,
811 )
812 });
813
814 // Ensure that mcp-1 starts up
815 {
816 let _server_events = assert_server_events(
817 &store,
818 vec![
819 (server_1_id.clone(), ContextServerStatus::Starting),
820 (server_1_id.clone(), ContextServerStatus::Running),
821 ],
822 cx,
823 );
824 cx.run_until_parked();
825 }
826
827 // Ensure that mcp-1 is restarted when the configuration was changed
828 {
829 let _server_events = assert_server_events(
830 &store,
831 vec![
832 (server_1_id.clone(), ContextServerStatus::Stopped),
833 (server_1_id.clone(), ContextServerStatus::Starting),
834 (server_1_id.clone(), ContextServerStatus::Running),
835 ],
836 cx,
837 );
838 set_context_server_configuration(
839 vec![(
840 server_1_id.0.clone(),
841 ContextServerConfiguration {
842 command: None,
843 settings: Some(json!({
844 "somevalue": false
845 })),
846 },
847 )],
848 cx,
849 );
850
851 cx.run_until_parked();
852 }
853
854 // Ensure that mcp-1 is not restarted when the configuration was not changed
855 {
856 let _server_events = assert_server_events(&store, vec![], cx);
857 set_context_server_configuration(
858 vec![(
859 server_1_id.0.clone(),
860 ContextServerConfiguration {
861 command: None,
862 settings: Some(json!({
863 "somevalue": false
864 })),
865 },
866 )],
867 cx,
868 );
869
870 cx.run_until_parked();
871 }
872
873 // Ensure that mcp-2 is started once it is added to the settings
874 {
875 let _server_events = assert_server_events(
876 &store,
877 vec![
878 (server_2_id.clone(), ContextServerStatus::Starting),
879 (server_2_id.clone(), ContextServerStatus::Running),
880 ],
881 cx,
882 );
883 set_context_server_configuration(
884 vec![
885 (
886 server_1_id.0.clone(),
887 ContextServerConfiguration {
888 command: None,
889 settings: Some(json!({
890 "somevalue": false
891 })),
892 },
893 ),
894 (
895 server_2_id.0.clone(),
896 ContextServerConfiguration {
897 command: None,
898 settings: Some(json!({
899 "somevalue": true
900 })),
901 },
902 ),
903 ],
904 cx,
905 );
906
907 cx.run_until_parked();
908 }
909
910 // Ensure that mcp-2 is removed once it is removed from the settings
911 {
912 let _server_events = assert_server_events(
913 &store,
914 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
915 cx,
916 );
917 set_context_server_configuration(
918 vec![(
919 server_1_id.0.clone(),
920 ContextServerConfiguration {
921 command: None,
922 settings: Some(json!({
923 "somevalue": false
924 })),
925 },
926 )],
927 cx,
928 );
929
930 cx.run_until_parked();
931
932 cx.update(|cx| {
933 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
934 });
935 }
936 }
937
938 fn set_context_server_configuration(
939 context_servers: Vec<(Arc<str>, ContextServerConfiguration)>,
940 cx: &mut TestAppContext,
941 ) {
942 cx.update(|cx| {
943 SettingsStore::update_global(cx, |store, cx| {
944 let mut settings = ProjectSettings::default();
945 for (id, config) in context_servers {
946 settings.context_servers.insert(id, config);
947 }
948 store
949 .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
950 .unwrap();
951 })
952 });
953 }
954
// Guard returned by `assert_server_events`; its `Drop` impl checks that the number of
// received events equals the number of expected ones.
struct ServerEvents {
    // Number of events observed so far; shared with the subscription callback.
    received_event_count: Rc<RefCell<usize>>,
    // Number of events the test expects to observe in total.
    expected_event_count: usize,
    // Keeps the event subscription alive for as long as the guard exists.
    _subscription: Subscription,
}
960
961 impl Drop for ServerEvents {
962 fn drop(&mut self) {
963 let actual_event_count = *self.received_event_count.borrow();
964 assert_eq!(
965 actual_event_count, self.expected_event_count,
966 "
967 Expected to receive {} context server store events, but received {} events",
968 self.expected_event_count, actual_event_count
969 );
970 }
971 }
972
973 fn assert_server_events(
974 store: &Entity<ContextServerStore>,
975 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
976 cx: &mut TestAppContext,
977 ) -> ServerEvents {
978 cx.update(|cx| {
979 let mut ix = 0;
980 let received_event_count = Rc::new(RefCell::new(0));
981 let expected_event_count = expected_events.len();
982 let subscription = cx.subscribe(store, {
983 let received_event_count = received_event_count.clone();
984 move |_, event, _| match event {
985 Event::ServerStatusChanged {
986 server_id: actual_server_id,
987 status: actual_status,
988 } => {
989 let (expected_server_id, expected_status) = &expected_events[ix];
990
991 assert_eq!(
992 actual_server_id, expected_server_id,
993 "Expected different server id at index {}",
994 ix
995 );
996 assert_eq!(
997 actual_status, expected_status,
998 "Expected different status at index {}",
999 ix
1000 );
1001 ix += 1;
1002 *received_event_count.borrow_mut() += 1;
1003 }
1004 }
1005 });
1006 ServerEvents {
1007 expected_event_count,
1008 received_event_count,
1009 _subscription: subscription,
1010 }
1011 })
1012 }
1013
1014 async fn setup_context_server_test(
1015 cx: &mut TestAppContext,
1016 files: serde_json::Value,
1017 context_server_configurations: Vec<(Arc<str>, ContextServerConfiguration)>,
1018 ) -> (Arc<FakeFs>, Entity<Project>) {
1019 cx.update(|cx| {
1020 let settings_store = SettingsStore::test(cx);
1021 cx.set_global(settings_store);
1022 Project::init_settings(cx);
1023 let mut settings = ProjectSettings::get_global(cx).clone();
1024 for (id, config) in context_server_configurations {
1025 settings.context_servers.insert(id, config);
1026 }
1027 ProjectSettings::override_global(settings, cx);
1028 });
1029
1030 let fs = FakeFs::new(cx.executor());
1031 fs.insert_tree(path!("/test"), files).await;
1032 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1033
1034 (fs, project)
1035 }
1036
1037 fn create_initialize_response(server_name: String) -> serde_json::Value {
1038 serde_json::to_value(&InitializeResponse {
1039 protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
1040 server_info: Implementation {
1041 name: server_name,
1042 version: "1.0.0".to_string(),
1043 },
1044 capabilities: ServerCapabilities::default(),
1045 meta: None,
1046 })
1047 .unwrap()
1048 }
1049
// In-memory `Transport` used by the tests: requests are answered by a callback and the
// responses are handed back through an unbounded channel.
struct FakeTransport {
    // Called with (request id, parsed request type, raw message); returning `Some(payload)`
    // produces a response, `None` produces no response.
    on_request: Arc<
        dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
            + Send
            + Sync,
    >,
    // Sending half used by `send` to enqueue responses.
    tx: futures::channel::mpsc::UnboundedSender<String>,
    // Receiving half drained by the stream returned from `receive`.
    rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
    // Used to inject random delays before a response is delivered.
    executor: BackgroundExecutor,
}
1060
1061 impl FakeTransport {
1062 fn new(
1063 executor: BackgroundExecutor,
1064 on_request: impl Fn(
1065 u64,
1066 Option<RequestType>,
1067 serde_json::Value,
1068 ) -> Option<serde_json::Value>
1069 + 'static
1070 + Send
1071 + Sync,
1072 ) -> Self {
1073 let (tx, rx) = futures::channel::mpsc::unbounded();
1074 Self {
1075 on_request: Arc::new(on_request),
1076 tx,
1077 rx: Arc::new(Mutex::new(rx)),
1078 executor,
1079 }
1080 }
1081 }
1082
1083 #[async_trait::async_trait]
1084 impl Transport for FakeTransport {
1085 async fn send(&self, message: String) -> Result<()> {
1086 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
1087 let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
1088
1089 if let Some(method) = msg.get("method") {
1090 let request_type = method
1091 .as_str()
1092 .and_then(|method| types::RequestType::try_from(method).ok());
1093 if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
1094 let response = serde_json::json!({
1095 "jsonrpc": "2.0",
1096 "id": id,
1097 "result": payload
1098 });
1099
1100 self.tx
1101 .unbounded_send(response.to_string())
1102 .context("sending a message")?;
1103 }
1104 }
1105 }
1106 Ok(())
1107 }
1108
1109 fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
1110 let rx = self.rx.clone();
1111 let executor = self.executor.clone();
1112 Box::pin(futures::stream::unfold(rx, move |rx| {
1113 let executor = executor.clone();
1114 async move {
1115 let mut rx_guard = rx.lock().await;
1116 executor.simulate_random_delay().await;
1117 if let Some(message) = rx_guard.next().await {
1118 drop(rx_guard);
1119 Some((message, rx))
1120 } else {
1121 None
1122 }
1123 }
1124 }))
1125 }
1126
1127 fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
1128 Box::pin(futures::stream::empty())
1129 }
1130 }
1131}