1pub mod extension;
2pub mod registry;
3
4use std::sync::Arc;
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::{ResultExt as _, rel_path::RelPath};
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 worktree_store::WorktreeStore,
19};
20
21pub fn init(cx: &mut App) {
22 extension::init(cx);
23}
24
// Actions in the `context_server` namespace. The `///` text on each action is
// left verbatim because the `actions!` macro may surface it as the action's
// user-visible documentation.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
32
/// Public, data-free view of a context server's lifecycle stage.
///
/// Emitted in [`Event::ServerStatusChanged`] and returned by
/// `ContextServerStore::status_for_server`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server is being started; startup has not finished yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped or removed.
    Stopped,
    /// Startup failed; carries the error message.
    Error(Arc<str>),
}
40
41impl ContextServerStatus {
42 fn from_state(state: &ContextServerState) -> Self {
43 match state {
44 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
45 ContextServerState::Running { .. } => ContextServerStatus::Running,
46 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
47 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
48 }
49 }
50}
51
52enum ContextServerState {
53 Starting {
54 server: Arc<ContextServer>,
55 configuration: Arc<ContextServerConfiguration>,
56 _task: Task<()>,
57 },
58 Running {
59 server: Arc<ContextServer>,
60 configuration: Arc<ContextServerConfiguration>,
61 },
62 Stopped {
63 server: Arc<ContextServer>,
64 configuration: Arc<ContextServerConfiguration>,
65 },
66 Error {
67 server: Arc<ContextServer>,
68 configuration: Arc<ContextServerConfiguration>,
69 error: Arc<str>,
70 },
71}
72
73impl ContextServerState {
74 pub fn server(&self) -> Arc<ContextServer> {
75 match self {
76 ContextServerState::Starting { server, .. } => server.clone(),
77 ContextServerState::Running { server, .. } => server.clone(),
78 ContextServerState::Stopped { server, .. } => server.clone(),
79 ContextServerState::Error { server, .. } => server.clone(),
80 }
81 }
82
83 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
84 match self {
85 ContextServerState::Starting { configuration, .. } => configuration.clone(),
86 ContextServerState::Running { configuration, .. } => configuration.clone(),
87 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
88 ContextServerState::Error { configuration, .. } => configuration.clone(),
89 }
90 }
91}
92
93#[derive(Debug, PartialEq, Eq)]
94pub enum ContextServerConfiguration {
95 Custom {
96 command: ContextServerCommand,
97 },
98 Extension {
99 command: ContextServerCommand,
100 settings: serde_json::Value,
101 },
102 Http {
103 url: url::Url,
104 headers: HashMap<String, String>,
105 },
106}
107
108impl ContextServerConfiguration {
109 pub fn command(&self) -> Option<&ContextServerCommand> {
110 match self {
111 ContextServerConfiguration::Custom { command } => Some(command),
112 ContextServerConfiguration::Extension { command, .. } => Some(command),
113 ContextServerConfiguration::Http { .. } => None,
114 }
115 }
116
117 pub async fn from_settings(
118 settings: ContextServerSettings,
119 id: ContextServerId,
120 registry: Entity<ContextServerDescriptorRegistry>,
121 worktree_store: Entity<WorktreeStore>,
122 cx: &AsyncApp,
123 ) -> Option<Self> {
124 match settings {
125 ContextServerSettings::Stdio {
126 enabled: _,
127 command,
128 } => Some(ContextServerConfiguration::Custom { command }),
129 ContextServerSettings::Extension {
130 enabled: _,
131 settings,
132 } => {
133 let descriptor = cx
134 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
135 .ok()
136 .flatten()?;
137
138 match descriptor.command(worktree_store, cx).await {
139 Ok(command) => {
140 Some(ContextServerConfiguration::Extension { command, settings })
141 }
142 Err(e) => {
143 log::error!(
144 "Failed to create context server configuration from settings: {e:#}"
145 );
146 None
147 }
148 }
149 }
150 ContextServerSettings::Http {
151 enabled: _,
152 url,
153 headers: auth,
154 } => {
155 let url = url::Url::parse(&url).log_err()?;
156 Some(ContextServerConfiguration::Http { url, headers: auth })
157 }
158 }
159 }
160}
161
162pub type ContextServerFactory =
163 Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
164
165pub struct ContextServerStore {
166 context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
167 servers: HashMap<ContextServerId, ContextServerState>,
168 worktree_store: Entity<WorktreeStore>,
169 project: WeakEntity<Project>,
170 registry: Entity<ContextServerDescriptorRegistry>,
171 update_servers_task: Option<Task<Result<()>>>,
172 context_server_factory: Option<ContextServerFactory>,
173 needs_server_update: bool,
174 _subscriptions: Vec<Subscription>,
175}
176
177pub enum Event {
178 ServerStatusChanged {
179 server_id: ContextServerId,
180 status: ContextServerStatus,
181 },
182}
183
184impl EventEmitter<Event> for ContextServerStore {}
185
186impl ContextServerStore {
187 pub fn new(
188 worktree_store: Entity<WorktreeStore>,
189 weak_project: WeakEntity<Project>,
190 cx: &mut Context<Self>,
191 ) -> Self {
192 Self::new_internal(
193 true,
194 None,
195 ContextServerDescriptorRegistry::default_global(cx),
196 worktree_store,
197 weak_project,
198 cx,
199 )
200 }
201
202 /// Returns all configured context server ids, excluding the ones that are disabled
203 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
204 self.context_server_settings
205 .iter()
206 .filter(|(_, settings)| settings.enabled())
207 .map(|(id, _)| ContextServerId(id.clone()))
208 .collect()
209 }
210
211 #[cfg(any(test, feature = "test-support"))]
212 pub fn test(
213 registry: Entity<ContextServerDescriptorRegistry>,
214 worktree_store: Entity<WorktreeStore>,
215 weak_project: WeakEntity<Project>,
216 cx: &mut Context<Self>,
217 ) -> Self {
218 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
219 }
220
221 #[cfg(any(test, feature = "test-support"))]
222 pub fn test_maintain_server_loop(
223 context_server_factory: Option<ContextServerFactory>,
224 registry: Entity<ContextServerDescriptorRegistry>,
225 worktree_store: Entity<WorktreeStore>,
226 weak_project: WeakEntity<Project>,
227 cx: &mut Context<Self>,
228 ) -> Self {
229 Self::new_internal(
230 true,
231 context_server_factory,
232 registry,
233 worktree_store,
234 weak_project,
235 cx,
236 )
237 }
238
239 fn new_internal(
240 maintain_server_loop: bool,
241 context_server_factory: Option<ContextServerFactory>,
242 registry: Entity<ContextServerDescriptorRegistry>,
243 worktree_store: Entity<WorktreeStore>,
244 weak_project: WeakEntity<Project>,
245 cx: &mut Context<Self>,
246 ) -> Self {
247 let subscriptions = if maintain_server_loop {
248 vec![
249 cx.observe(®istry, |this, _registry, cx| {
250 this.available_context_servers_changed(cx);
251 }),
252 cx.observe_global::<SettingsStore>(|this, cx| {
253 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
254 if &this.context_server_settings == settings {
255 return;
256 }
257 this.context_server_settings = settings.clone();
258 this.available_context_servers_changed(cx);
259 }),
260 ]
261 } else {
262 Vec::new()
263 };
264
265 let mut this = Self {
266 _subscriptions: subscriptions,
267 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
268 .clone(),
269 worktree_store,
270 project: weak_project,
271 registry,
272 needs_server_update: false,
273 servers: HashMap::default(),
274 update_servers_task: None,
275 context_server_factory,
276 };
277 if maintain_server_loop {
278 this.available_context_servers_changed(cx);
279 }
280 this
281 }
282
283 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
284 self.servers.get(id).map(|state| state.server())
285 }
286
287 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
288 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
289 Some(server.clone())
290 } else {
291 None
292 }
293 }
294
295 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
296 self.servers.get(id).map(ContextServerStatus::from_state)
297 }
298
299 pub fn configuration_for_server(
300 &self,
301 id: &ContextServerId,
302 ) -> Option<Arc<ContextServerConfiguration>> {
303 self.servers.get(id).map(|state| state.configuration())
304 }
305
306 pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
307 self.servers
308 .keys()
309 .cloned()
310 .chain(
311 self.registry
312 .read(cx)
313 .context_server_descriptors()
314 .into_iter()
315 .map(|(id, _)| ContextServerId(id)),
316 )
317 .collect()
318 }
319
320 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
321 self.servers
322 .values()
323 .filter_map(|state| {
324 if let ContextServerState::Running { server, .. } = state {
325 Some(server.clone())
326 } else {
327 None
328 }
329 })
330 .collect()
331 }
332
333 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
334 cx.spawn(async move |this, cx| {
335 let this = this.upgrade().context("Context server store dropped")?;
336 let settings = this
337 .update(cx, |this, _| {
338 this.context_server_settings.get(&server.id().0).cloned()
339 })
340 .ok()
341 .flatten()
342 .context("Failed to get context server settings")?;
343
344 if !settings.enabled() {
345 return Ok(());
346 }
347
348 let (registry, worktree_store) = this.update(cx, |this, _| {
349 (this.registry.clone(), this.worktree_store.clone())
350 })?;
351 let configuration = ContextServerConfiguration::from_settings(
352 settings,
353 server.id(),
354 registry,
355 worktree_store,
356 cx,
357 )
358 .await
359 .context("Failed to create context server configuration")?;
360
361 this.update(cx, |this, cx| {
362 this.run_server(server, Arc::new(configuration), cx)
363 })
364 })
365 .detach_and_log_err(cx);
366 }
367
368 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
369 if matches!(
370 self.servers.get(id),
371 Some(ContextServerState::Stopped { .. })
372 ) {
373 return Ok(());
374 }
375
376 let state = self
377 .servers
378 .remove(id)
379 .context("Context server not found")?;
380
381 let server = state.server();
382 let configuration = state.configuration();
383 let mut result = Ok(());
384 if let ContextServerState::Running { server, .. } = &state {
385 result = server.stop();
386 }
387 drop(state);
388
389 self.update_server_state(
390 id.clone(),
391 ContextServerState::Stopped {
392 configuration,
393 server,
394 },
395 cx,
396 );
397
398 result
399 }
400
401 fn run_server(
402 &mut self,
403 server: Arc<ContextServer>,
404 configuration: Arc<ContextServerConfiguration>,
405 cx: &mut Context<Self>,
406 ) {
407 let id = server.id();
408 if matches!(
409 self.servers.get(&id),
410 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
411 ) {
412 self.stop_server(&id, cx).log_err();
413 }
414
415 let task = cx.spawn({
416 let id = server.id();
417 let server = server.clone();
418 let configuration = configuration.clone();
419 async move |this, cx| {
420 match server.clone().start(cx).await {
421 Ok(_) => {
422 debug_assert!(server.client().is_some());
423
424 this.update(cx, |this, cx| {
425 this.update_server_state(
426 id.clone(),
427 ContextServerState::Running {
428 server,
429 configuration,
430 },
431 cx,
432 )
433 })
434 .log_err()
435 }
436 Err(err) => {
437 log::error!("{} context server failed to start: {}", id, err);
438 this.update(cx, |this, cx| {
439 this.update_server_state(
440 id.clone(),
441 ContextServerState::Error {
442 configuration,
443 server,
444 error: err.to_string().into(),
445 },
446 cx,
447 )
448 })
449 .log_err()
450 }
451 };
452 }
453 });
454
455 self.update_server_state(
456 id.clone(),
457 ContextServerState::Starting {
458 configuration,
459 _task: task,
460 server,
461 },
462 cx,
463 );
464 }
465
466 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
467 let state = self
468 .servers
469 .remove(id)
470 .context("Context server not found")?;
471 drop(state);
472 cx.emit(Event::ServerStatusChanged {
473 server_id: id.clone(),
474 status: ContextServerStatus::Stopped,
475 });
476 Ok(())
477 }
478
479 fn create_context_server(
480 &self,
481 id: ContextServerId,
482 configuration: Arc<ContextServerConfiguration>,
483 cx: &mut Context<Self>,
484 ) -> Result<Arc<ContextServer>> {
485 if let Some(factory) = self.context_server_factory.as_ref() {
486 return Ok(factory(id, configuration));
487 }
488
489 match configuration.as_ref() {
490 ContextServerConfiguration::Http { url, headers } => Ok(Arc::new(ContextServer::http(
491 id,
492 url,
493 headers.clone(),
494 cx.http_client(),
495 cx.background_executor().clone(),
496 )?)),
497 _ => {
498 let root_path = self
499 .project
500 .read_with(cx, |project, cx| project.active_project_directory(cx))
501 .ok()
502 .flatten()
503 .or_else(|| {
504 self.worktree_store.read_with(cx, |store, cx| {
505 store.visible_worktrees(cx).fold(None, |acc, item| {
506 if acc.is_none() {
507 item.read(cx).root_dir()
508 } else {
509 acc
510 }
511 })
512 })
513 });
514 Ok(Arc::new(ContextServer::stdio(
515 id,
516 configuration.command().unwrap().clone(),
517 root_path,
518 )))
519 }
520 }
521 }
522
523 fn resolve_context_server_settings<'a>(
524 worktree_store: &'a Entity<WorktreeStore>,
525 cx: &'a App,
526 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
527 let location = worktree_store
528 .read(cx)
529 .visible_worktrees(cx)
530 .next()
531 .map(|worktree| settings::SettingsLocation {
532 worktree_id: worktree.read(cx).id(),
533 path: RelPath::empty(),
534 });
535 &ProjectSettings::get(location, cx).context_servers
536 }
537
538 fn update_server_state(
539 &mut self,
540 id: ContextServerId,
541 state: ContextServerState,
542 cx: &mut Context<Self>,
543 ) {
544 let status = ContextServerStatus::from_state(&state);
545 self.servers.insert(id.clone(), state);
546 cx.emit(Event::ServerStatusChanged {
547 server_id: id,
548 status,
549 });
550 }
551
552 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
553 if self.update_servers_task.is_some() {
554 self.needs_server_update = true;
555 } else {
556 self.needs_server_update = false;
557 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
558 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
559 log::error!("Error maintaining context servers: {}", err);
560 }
561
562 this.update(cx, |this, cx| {
563 this.update_servers_task.take();
564 if this.needs_server_update {
565 this.available_context_servers_changed(cx);
566 }
567 })?;
568
569 Ok(())
570 }));
571 }
572 }
573
574 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
575 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
576 (
577 this.context_server_settings.clone(),
578 this.registry.clone(),
579 this.worktree_store.clone(),
580 )
581 })?;
582
583 for (id, _) in
584 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
585 {
586 configured_servers
587 .entry(id)
588 .or_insert(ContextServerSettings::default_extension());
589 }
590
591 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
592 configured_servers
593 .into_iter()
594 .partition(|(_, settings)| settings.enabled());
595
596 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
597 let id = ContextServerId(id);
598 ContextServerConfiguration::from_settings(
599 settings,
600 id.clone(),
601 registry.clone(),
602 worktree_store.clone(),
603 cx,
604 )
605 .map(|config| (id, config))
606 }))
607 .await
608 .into_iter()
609 .filter_map(|(id, config)| config.map(|config| (id, config)))
610 .collect::<HashMap<_, _>>();
611
612 let mut servers_to_start = Vec::new();
613 let mut servers_to_remove = HashSet::default();
614 let mut servers_to_stop = HashSet::default();
615
616 this.update(cx, |this, cx| {
617 for server_id in this.servers.keys() {
618 // All servers that are not in desired_servers should be removed from the store.
619 // This can happen if the user removed a server from the context server settings.
620 if !configured_servers.contains_key(server_id) {
621 if disabled_servers.contains_key(&server_id.0) {
622 servers_to_stop.insert(server_id.clone());
623 } else {
624 servers_to_remove.insert(server_id.clone());
625 }
626 }
627 }
628
629 for (id, config) in configured_servers {
630 let state = this.servers.get(&id);
631 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
632 let existing_config = state.as_ref().map(|state| state.configuration());
633 if existing_config.as_deref() != Some(&config) || is_stopped {
634 let config = Arc::new(config);
635 let server = this.create_context_server(id.clone(), config.clone(), cx)?;
636 servers_to_start.push((server, config));
637 if this.servers.contains_key(&id) {
638 servers_to_stop.insert(id);
639 }
640 }
641 }
642
643 anyhow::Ok(())
644 })??;
645
646 this.update(cx, |this, cx| {
647 for id in servers_to_stop {
648 this.stop_server(&id, cx)?;
649 }
650 for id in servers_to_remove {
651 this.remove_server(&id, cx)?;
652 }
653 for (server, config) in servers_to_start {
654 this.run_server(server, config, cx);
655 }
656 anyhow::Ok(())
657 })?
658 }
659}
660
661#[cfg(test)]
662mod tests {
663 use super::*;
664 use crate::{
665 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
666 project_settings::ProjectSettings,
667 };
668 use context_server::test::create_fake_transport;
669 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
670 use http_client::{FakeHttpClient, Response};
671 use serde_json::json;
672 use std::{cell::RefCell, path::PathBuf, rc::Rc};
673 use util::path;
674
675 #[gpui::test]
676 async fn test_context_server_status(cx: &mut TestAppContext) {
677 const SERVER_1_ID: &str = "mcp-1";
678 const SERVER_2_ID: &str = "mcp-2";
679
680 let (_fs, project) = setup_context_server_test(
681 cx,
682 json!({"code.rs": ""}),
683 vec![
684 (SERVER_1_ID.into(), dummy_server_settings()),
685 (SERVER_2_ID.into(), dummy_server_settings()),
686 ],
687 )
688 .await;
689
690 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
691 let store = cx.new(|cx| {
692 ContextServerStore::test(
693 registry.clone(),
694 project.read(cx).worktree_store(),
695 project.downgrade(),
696 cx,
697 )
698 });
699
700 let server_1_id = ContextServerId(SERVER_1_ID.into());
701 let server_2_id = ContextServerId(SERVER_2_ID.into());
702
703 let server_1 = Arc::new(ContextServer::new(
704 server_1_id.clone(),
705 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
706 ));
707 let server_2 = Arc::new(ContextServer::new(
708 server_2_id.clone(),
709 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
710 ));
711
712 store.update(cx, |store, cx| store.start_server(server_1, cx));
713
714 cx.run_until_parked();
715
716 cx.update(|cx| {
717 assert_eq!(
718 store.read(cx).status_for_server(&server_1_id),
719 Some(ContextServerStatus::Running)
720 );
721 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
722 });
723
724 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
725
726 cx.run_until_parked();
727
728 cx.update(|cx| {
729 assert_eq!(
730 store.read(cx).status_for_server(&server_1_id),
731 Some(ContextServerStatus::Running)
732 );
733 assert_eq!(
734 store.read(cx).status_for_server(&server_2_id),
735 Some(ContextServerStatus::Running)
736 );
737 });
738
739 store
740 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
741 .unwrap();
742
743 cx.update(|cx| {
744 assert_eq!(
745 store.read(cx).status_for_server(&server_1_id),
746 Some(ContextServerStatus::Running)
747 );
748 assert_eq!(
749 store.read(cx).status_for_server(&server_2_id),
750 Some(ContextServerStatus::Stopped)
751 );
752 });
753 }
754
755 #[gpui::test]
756 async fn test_context_server_status_events(cx: &mut TestAppContext) {
757 const SERVER_1_ID: &str = "mcp-1";
758 const SERVER_2_ID: &str = "mcp-2";
759
760 let (_fs, project) = setup_context_server_test(
761 cx,
762 json!({"code.rs": ""}),
763 vec![
764 (SERVER_1_ID.into(), dummy_server_settings()),
765 (SERVER_2_ID.into(), dummy_server_settings()),
766 ],
767 )
768 .await;
769
770 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
771 let store = cx.new(|cx| {
772 ContextServerStore::test(
773 registry.clone(),
774 project.read(cx).worktree_store(),
775 project.downgrade(),
776 cx,
777 )
778 });
779
780 let server_1_id = ContextServerId(SERVER_1_ID.into());
781 let server_2_id = ContextServerId(SERVER_2_ID.into());
782
783 let server_1 = Arc::new(ContextServer::new(
784 server_1_id.clone(),
785 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
786 ));
787 let server_2 = Arc::new(ContextServer::new(
788 server_2_id.clone(),
789 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
790 ));
791
792 let _server_events = assert_server_events(
793 &store,
794 vec![
795 (server_1_id.clone(), ContextServerStatus::Starting),
796 (server_1_id, ContextServerStatus::Running),
797 (server_2_id.clone(), ContextServerStatus::Starting),
798 (server_2_id.clone(), ContextServerStatus::Running),
799 (server_2_id.clone(), ContextServerStatus::Stopped),
800 ],
801 cx,
802 );
803
804 store.update(cx, |store, cx| store.start_server(server_1, cx));
805
806 cx.run_until_parked();
807
808 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
809
810 cx.run_until_parked();
811
812 store
813 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
814 .unwrap();
815 }
816
817 #[gpui::test(iterations = 25)]
818 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
819 const SERVER_1_ID: &str = "mcp-1";
820
821 let (_fs, project) = setup_context_server_test(
822 cx,
823 json!({"code.rs": ""}),
824 vec![(SERVER_1_ID.into(), dummy_server_settings())],
825 )
826 .await;
827
828 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
829 let store = cx.new(|cx| {
830 ContextServerStore::test(
831 registry.clone(),
832 project.read(cx).worktree_store(),
833 project.downgrade(),
834 cx,
835 )
836 });
837
838 let server_id = ContextServerId(SERVER_1_ID.into());
839
840 let server_with_same_id_1 = Arc::new(ContextServer::new(
841 server_id.clone(),
842 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
843 ));
844 let server_with_same_id_2 = Arc::new(ContextServer::new(
845 server_id.clone(),
846 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
847 ));
848
849 // If we start another server with the same id, we should report that we stopped the previous one
850 let _server_events = assert_server_events(
851 &store,
852 vec![
853 (server_id.clone(), ContextServerStatus::Starting),
854 (server_id.clone(), ContextServerStatus::Stopped),
855 (server_id.clone(), ContextServerStatus::Starting),
856 (server_id.clone(), ContextServerStatus::Running),
857 ],
858 cx,
859 );
860
861 store.update(cx, |store, cx| {
862 store.start_server(server_with_same_id_1.clone(), cx)
863 });
864 store.update(cx, |store, cx| {
865 store.start_server(server_with_same_id_2.clone(), cx)
866 });
867
868 cx.run_until_parked();
869
870 cx.update(|cx| {
871 assert_eq!(
872 store.read(cx).status_for_server(&server_id),
873 Some(ContextServerStatus::Running)
874 );
875 });
876 }
877
878 #[gpui::test]
879 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
880 const SERVER_1_ID: &str = "mcp-1";
881 const SERVER_2_ID: &str = "mcp-2";
882
883 let server_1_id = ContextServerId(SERVER_1_ID.into());
884 let server_2_id = ContextServerId(SERVER_2_ID.into());
885
886 let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
887
888 let (_fs, project) = setup_context_server_test(
889 cx,
890 json!({"code.rs": ""}),
891 vec![(
892 SERVER_1_ID.into(),
893 ContextServerSettings::Extension {
894 enabled: true,
895 settings: json!({
896 "somevalue": true
897 }),
898 },
899 )],
900 )
901 .await;
902
903 let executor = cx.executor();
904 let registry = cx.new(|cx| {
905 let mut registry = ContextServerDescriptorRegistry::new();
906 registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
907 registry
908 });
909 let store = cx.new(|cx| {
910 ContextServerStore::test_maintain_server_loop(
911 Some(Box::new(move |id, _| {
912 Arc::new(ContextServer::new(
913 id.clone(),
914 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
915 ))
916 })),
917 registry.clone(),
918 project.read(cx).worktree_store(),
919 project.downgrade(),
920 cx,
921 )
922 });
923
924 // Ensure that mcp-1 starts up
925 {
926 let _server_events = assert_server_events(
927 &store,
928 vec![
929 (server_1_id.clone(), ContextServerStatus::Starting),
930 (server_1_id.clone(), ContextServerStatus::Running),
931 ],
932 cx,
933 );
934 cx.run_until_parked();
935 }
936
937 // Ensure that mcp-1 is restarted when the configuration was changed
938 {
939 let _server_events = assert_server_events(
940 &store,
941 vec![
942 (server_1_id.clone(), ContextServerStatus::Stopped),
943 (server_1_id.clone(), ContextServerStatus::Starting),
944 (server_1_id.clone(), ContextServerStatus::Running),
945 ],
946 cx,
947 );
948 set_context_server_configuration(
949 vec![(
950 server_1_id.0.clone(),
951 settings::ContextServerSettingsContent::Extension {
952 enabled: true,
953 settings: json!({
954 "somevalue": false
955 }),
956 },
957 )],
958 cx,
959 );
960
961 cx.run_until_parked();
962 }
963
964 // Ensure that mcp-1 is not restarted when the configuration was not changed
965 {
966 let _server_events = assert_server_events(&store, vec![], cx);
967 set_context_server_configuration(
968 vec![(
969 server_1_id.0.clone(),
970 settings::ContextServerSettingsContent::Extension {
971 enabled: true,
972 settings: json!({
973 "somevalue": false
974 }),
975 },
976 )],
977 cx,
978 );
979
980 cx.run_until_parked();
981 }
982
983 // Ensure that mcp-2 is started once it is added to the settings
984 {
985 let _server_events = assert_server_events(
986 &store,
987 vec![
988 (server_2_id.clone(), ContextServerStatus::Starting),
989 (server_2_id.clone(), ContextServerStatus::Running),
990 ],
991 cx,
992 );
993 set_context_server_configuration(
994 vec![
995 (
996 server_1_id.0.clone(),
997 settings::ContextServerSettingsContent::Extension {
998 enabled: true,
999 settings: json!({
1000 "somevalue": false
1001 }),
1002 },
1003 ),
1004 (
1005 server_2_id.0.clone(),
1006 settings::ContextServerSettingsContent::Stdio {
1007 enabled: true,
1008 command: ContextServerCommand {
1009 path: "somebinary".into(),
1010 args: vec!["arg".to_string()],
1011 env: None,
1012 timeout: None,
1013 },
1014 },
1015 ),
1016 ],
1017 cx,
1018 );
1019
1020 cx.run_until_parked();
1021 }
1022
1023 // Ensure that mcp-2 is restarted once the args have changed
1024 {
1025 let _server_events = assert_server_events(
1026 &store,
1027 vec![
1028 (server_2_id.clone(), ContextServerStatus::Stopped),
1029 (server_2_id.clone(), ContextServerStatus::Starting),
1030 (server_2_id.clone(), ContextServerStatus::Running),
1031 ],
1032 cx,
1033 );
1034 set_context_server_configuration(
1035 vec![
1036 (
1037 server_1_id.0.clone(),
1038 settings::ContextServerSettingsContent::Extension {
1039 enabled: true,
1040 settings: json!({
1041 "somevalue": false
1042 }),
1043 },
1044 ),
1045 (
1046 server_2_id.0.clone(),
1047 settings::ContextServerSettingsContent::Stdio {
1048 enabled: true,
1049 command: ContextServerCommand {
1050 path: "somebinary".into(),
1051 args: vec!["anotherArg".to_string()],
1052 env: None,
1053 timeout: None,
1054 },
1055 },
1056 ),
1057 ],
1058 cx,
1059 );
1060
1061 cx.run_until_parked();
1062 }
1063
1064 // Ensure that mcp-2 is removed once it is removed from the settings
1065 {
1066 let _server_events = assert_server_events(
1067 &store,
1068 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1069 cx,
1070 );
1071 set_context_server_configuration(
1072 vec![(
1073 server_1_id.0.clone(),
1074 settings::ContextServerSettingsContent::Extension {
1075 enabled: true,
1076 settings: json!({
1077 "somevalue": false
1078 }),
1079 },
1080 )],
1081 cx,
1082 );
1083
1084 cx.run_until_parked();
1085
1086 cx.update(|cx| {
1087 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1088 });
1089 }
1090
1091 // Ensure that nothing happens if the settings do not change
1092 {
1093 let _server_events = assert_server_events(&store, vec![], cx);
1094 set_context_server_configuration(
1095 vec![(
1096 server_1_id.0.clone(),
1097 settings::ContextServerSettingsContent::Extension {
1098 enabled: true,
1099 settings: json!({
1100 "somevalue": false
1101 }),
1102 },
1103 )],
1104 cx,
1105 );
1106
1107 cx.run_until_parked();
1108
1109 cx.update(|cx| {
1110 assert_eq!(
1111 store.read(cx).status_for_server(&server_1_id),
1112 Some(ContextServerStatus::Running)
1113 );
1114 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1115 });
1116 }
1117 }
1118
1119 #[gpui::test]
1120 async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1121 const SERVER_1_ID: &str = "mcp-1";
1122
1123 let server_1_id = ContextServerId(SERVER_1_ID.into());
1124
1125 let (_fs, project) = setup_context_server_test(
1126 cx,
1127 json!({"code.rs": ""}),
1128 vec![(
1129 SERVER_1_ID.into(),
1130 ContextServerSettings::Stdio {
1131 enabled: true,
1132 command: ContextServerCommand {
1133 path: "somebinary".into(),
1134 args: vec!["arg".to_string()],
1135 env: None,
1136 timeout: None,
1137 },
1138 },
1139 )],
1140 )
1141 .await;
1142
1143 let executor = cx.executor();
1144 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1145 let store = cx.new(|cx| {
1146 ContextServerStore::test_maintain_server_loop(
1147 Some(Box::new(move |id, _| {
1148 Arc::new(ContextServer::new(
1149 id.clone(),
1150 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1151 ))
1152 })),
1153 registry.clone(),
1154 project.read(cx).worktree_store(),
1155 project.downgrade(),
1156 cx,
1157 )
1158 });
1159
1160 // Ensure that mcp-1 starts up
1161 {
1162 let _server_events = assert_server_events(
1163 &store,
1164 vec![
1165 (server_1_id.clone(), ContextServerStatus::Starting),
1166 (server_1_id.clone(), ContextServerStatus::Running),
1167 ],
1168 cx,
1169 );
1170 cx.run_until_parked();
1171 }
1172
1173 // Ensure that mcp-1 is stopped once it is disabled.
1174 {
1175 let _server_events = assert_server_events(
1176 &store,
1177 vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1178 cx,
1179 );
1180 set_context_server_configuration(
1181 vec![(
1182 server_1_id.0.clone(),
1183 settings::ContextServerSettingsContent::Stdio {
1184 enabled: false,
1185 command: ContextServerCommand {
1186 path: "somebinary".into(),
1187 args: vec!["arg".to_string()],
1188 env: None,
1189 timeout: None,
1190 },
1191 },
1192 )],
1193 cx,
1194 );
1195
1196 cx.run_until_parked();
1197 }
1198
1199 // Ensure that mcp-1 is started once it is enabled again.
1200 {
1201 let _server_events = assert_server_events(
1202 &store,
1203 vec![
1204 (server_1_id.clone(), ContextServerStatus::Starting),
1205 (server_1_id.clone(), ContextServerStatus::Running),
1206 ],
1207 cx,
1208 );
1209 set_context_server_configuration(
1210 vec![(
1211 server_1_id.0.clone(),
1212 settings::ContextServerSettingsContent::Stdio {
1213 enabled: true,
1214 command: ContextServerCommand {
1215 path: "somebinary".into(),
1216 args: vec!["arg".to_string()],
1217 timeout: None,
1218 env: None,
1219 },
1220 },
1221 )],
1222 cx,
1223 );
1224
1225 cx.run_until_parked();
1226 }
1227 }
1228
1229 fn set_context_server_configuration(
1230 context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1231 cx: &mut TestAppContext,
1232 ) {
1233 cx.update(|cx| {
1234 SettingsStore::update_global(cx, |store, cx| {
1235 store.update_user_settings(cx, |content| {
1236 content.project.context_servers.clear();
1237 for (id, config) in context_servers {
1238 content.project.context_servers.insert(id, config);
1239 }
1240 });
1241 })
1242 });
1243 }
1244
1245 #[gpui::test]
1246 async fn test_remote_context_server(cx: &mut TestAppContext) {
1247 const SERVER_ID: &str = "remote-server";
1248 let server_id = ContextServerId(SERVER_ID.into());
1249 let server_url = "http://example.com/api";
1250
1251 let (_fs, project) = setup_context_server_test(
1252 cx,
1253 json!({ "code.rs": "" }),
1254 vec![(
1255 SERVER_ID.into(),
1256 ContextServerSettings::Http {
1257 enabled: true,
1258 url: server_url.to_string(),
1259 headers: Default::default(),
1260 },
1261 )],
1262 )
1263 .await;
1264
1265 let client = FakeHttpClient::create(|_| async move {
1266 use http_client::AsyncBody;
1267
1268 let response = Response::builder()
1269 .status(200)
1270 .header("Content-Type", "application/json")
1271 .body(AsyncBody::from(
1272 serde_json::to_string(&json!({
1273 "jsonrpc": "2.0",
1274 "id": 0,
1275 "result": {
1276 "protocolVersion": "2024-11-05",
1277 "capabilities": {},
1278 "serverInfo": {
1279 "name": "test-server",
1280 "version": "1.0.0"
1281 }
1282 }
1283 }))
1284 .unwrap(),
1285 ))
1286 .unwrap();
1287 Ok(response)
1288 });
1289 cx.update(|cx| cx.set_http_client(client));
1290 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1291 let store = cx.new(|cx| {
1292 ContextServerStore::test_maintain_server_loop(
1293 None,
1294 registry.clone(),
1295 project.read(cx).worktree_store(),
1296 project.downgrade(),
1297 cx,
1298 )
1299 });
1300
1301 let _server_events = assert_server_events(
1302 &store,
1303 vec![
1304 (server_id.clone(), ContextServerStatus::Starting),
1305 (server_id.clone(), ContextServerStatus::Running),
1306 ],
1307 cx,
1308 );
1309 cx.run_until_parked();
1310 }
1311
1312 struct ServerEvents {
1313 received_event_count: Rc<RefCell<usize>>,
1314 expected_event_count: usize,
1315 _subscription: Subscription,
1316 }
1317
1318 impl Drop for ServerEvents {
1319 fn drop(&mut self) {
1320 let actual_event_count = *self.received_event_count.borrow();
1321 assert_eq!(
1322 actual_event_count, self.expected_event_count,
1323 "
1324 Expected to receive {} context server store events, but received {} events",
1325 self.expected_event_count, actual_event_count
1326 );
1327 }
1328 }
1329
1330 fn dummy_server_settings() -> ContextServerSettings {
1331 ContextServerSettings::Stdio {
1332 enabled: true,
1333 command: ContextServerCommand {
1334 path: "somebinary".into(),
1335 args: vec!["arg".to_string()],
1336 env: None,
1337 timeout: None,
1338 },
1339 }
1340 }
1341
1342 fn assert_server_events(
1343 store: &Entity<ContextServerStore>,
1344 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1345 cx: &mut TestAppContext,
1346 ) -> ServerEvents {
1347 cx.update(|cx| {
1348 let mut ix = 0;
1349 let received_event_count = Rc::new(RefCell::new(0));
1350 let expected_event_count = expected_events.len();
1351 let subscription = cx.subscribe(store, {
1352 let received_event_count = received_event_count.clone();
1353 move |_, event, _| match event {
1354 Event::ServerStatusChanged {
1355 server_id: actual_server_id,
1356 status: actual_status,
1357 } => {
1358 let (expected_server_id, expected_status) = &expected_events[ix];
1359
1360 assert_eq!(
1361 actual_server_id, expected_server_id,
1362 "Expected different server id at index {}",
1363 ix
1364 );
1365 assert_eq!(
1366 actual_status, expected_status,
1367 "Expected different status at index {}",
1368 ix
1369 );
1370 ix += 1;
1371 *received_event_count.borrow_mut() += 1;
1372 }
1373 }
1374 });
1375 ServerEvents {
1376 expected_event_count,
1377 received_event_count,
1378 _subscription: subscription,
1379 }
1380 })
1381 }
1382
1383 async fn setup_context_server_test(
1384 cx: &mut TestAppContext,
1385 files: serde_json::Value,
1386 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1387 ) -> (Arc<FakeFs>, Entity<Project>) {
1388 cx.update(|cx| {
1389 let settings_store = SettingsStore::test(cx);
1390 cx.set_global(settings_store);
1391 let mut settings = ProjectSettings::get_global(cx).clone();
1392 for (id, config) in context_server_configurations {
1393 settings.context_servers.insert(id, config);
1394 }
1395 ProjectSettings::override_global(settings, cx);
1396 });
1397
1398 let fs = FakeFs::new(cx.executor());
1399 fs.insert_tree(path!("/test"), files).await;
1400 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1401
1402 (fs, project)
1403 }
1404
1405 struct FakeContextServerDescriptor {
1406 path: PathBuf,
1407 }
1408
1409 impl FakeContextServerDescriptor {
1410 fn new(path: impl Into<PathBuf>) -> Self {
1411 Self { path: path.into() }
1412 }
1413 }
1414
1415 impl ContextServerDescriptor for FakeContextServerDescriptor {
1416 fn command(
1417 &self,
1418 _worktree_store: Entity<WorktreeStore>,
1419 _cx: &AsyncApp,
1420 ) -> Task<Result<ContextServerCommand>> {
1421 Task::ready(Ok(ContextServerCommand {
1422 path: self.path.clone(),
1423 args: vec!["arg1".to_string(), "arg2".to_string()],
1424 env: None,
1425 timeout: None,
1426 }))
1427 }
1428
1429 fn configuration(
1430 &self,
1431 _worktree_store: Entity<WorktreeStore>,
1432 _cx: &AsyncApp,
1433 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1434 Task::ready(Ok(None))
1435 }
1436 }
1437}