1pub mod extension;
2pub mod registry;
3
4use std::sync::Arc;
5use std::time::Duration;
6
7use anyhow::{Context as _, Result};
8use collections::{HashMap, HashSet};
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use futures::{FutureExt as _, future::join_all};
11use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
12use registry::ContextServerDescriptorRegistry;
13use settings::{Settings as _, SettingsStore};
14use util::{ResultExt as _, rel_path::RelPath};
15
16use crate::{
17 Project,
18 project_settings::{ContextServerSettings, ProjectSettings},
19 worktree_store::WorktreeStore,
20};
21
22pub fn init(cx: &mut App) {
23 extension::init(cx);
24}
25
26actions!(
27 context_server,
28 [
29 /// Restarts the context server.
30 Restart
31 ]
32);
33
/// Lifecycle status of a context server as reported to observers.
///
/// This is the handle-free projection of the store's internal per-server state: it carries
/// no server or configuration, so it is cheap to clone, compare and hash, and it is what
/// status-change events transport.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server has been handed off to a startup task that has not finished yet.
    Starting,
    /// Startup completed successfully.
    Running,
    /// The server was stopped (or removed) by the store.
    Stopped,
    /// Startup failed; the payload is the rendered error message.
    Error(Arc<str>),
}
41
42impl ContextServerStatus {
43 fn from_state(state: &ContextServerState) -> Self {
44 match state {
45 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
46 ContextServerState::Running { .. } => ContextServerStatus::Running,
47 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
48 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
49 }
50 }
51}
52
53enum ContextServerState {
54 Starting {
55 server: Arc<ContextServer>,
56 configuration: Arc<ContextServerConfiguration>,
57 _task: Task<()>,
58 },
59 Running {
60 server: Arc<ContextServer>,
61 configuration: Arc<ContextServerConfiguration>,
62 },
63 Stopped {
64 server: Arc<ContextServer>,
65 configuration: Arc<ContextServerConfiguration>,
66 },
67 Error {
68 server: Arc<ContextServer>,
69 configuration: Arc<ContextServerConfiguration>,
70 error: Arc<str>,
71 },
72}
73
74impl ContextServerState {
75 pub fn server(&self) -> Arc<ContextServer> {
76 match self {
77 ContextServerState::Starting { server, .. } => server.clone(),
78 ContextServerState::Running { server, .. } => server.clone(),
79 ContextServerState::Stopped { server, .. } => server.clone(),
80 ContextServerState::Error { server, .. } => server.clone(),
81 }
82 }
83
84 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
85 match self {
86 ContextServerState::Starting { configuration, .. } => configuration.clone(),
87 ContextServerState::Running { configuration, .. } => configuration.clone(),
88 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
89 ContextServerState::Error { configuration, .. } => configuration.clone(),
90 }
91 }
92}
93
94#[derive(Debug, PartialEq, Eq)]
95pub enum ContextServerConfiguration {
96 Custom {
97 command: ContextServerCommand,
98 },
99 Extension {
100 command: ContextServerCommand,
101 settings: serde_json::Value,
102 },
103 Http {
104 url: url::Url,
105 headers: HashMap<String, String>,
106 timeout: Option<u64>,
107 },
108}
109
110impl ContextServerConfiguration {
111 pub fn command(&self) -> Option<&ContextServerCommand> {
112 match self {
113 ContextServerConfiguration::Custom { command } => Some(command),
114 ContextServerConfiguration::Extension { command, .. } => Some(command),
115 ContextServerConfiguration::Http { .. } => None,
116 }
117 }
118
119 pub async fn from_settings(
120 settings: ContextServerSettings,
121 id: ContextServerId,
122 registry: Entity<ContextServerDescriptorRegistry>,
123 worktree_store: Entity<WorktreeStore>,
124 cx: &AsyncApp,
125 ) -> Option<Self> {
126 match settings {
127 ContextServerSettings::Stdio {
128 enabled: _,
129 command,
130 } => Some(ContextServerConfiguration::Custom { command }),
131 ContextServerSettings::Extension {
132 enabled: _,
133 settings,
134 } => {
135 let descriptor = cx
136 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
137 .ok()
138 .flatten()?;
139
140 match descriptor.command(worktree_store, cx).await {
141 Ok(command) => {
142 Some(ContextServerConfiguration::Extension { command, settings })
143 }
144 Err(e) => {
145 log::error!(
146 "Failed to create context server configuration from settings: {e:#}"
147 );
148 None
149 }
150 }
151 }
152 ContextServerSettings::Http {
153 enabled: _,
154 url,
155 headers: auth,
156 timeout,
157 } => {
158 let url = url::Url::parse(&url).log_err()?;
159 Some(ContextServerConfiguration::Http {
160 url,
161 headers: auth,
162 timeout,
163 })
164 }
165 }
166 }
167}
168
169pub type ContextServerFactory =
170 Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
171
172pub struct ContextServerStore {
173 context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
174 servers: HashMap<ContextServerId, ContextServerState>,
175 worktree_store: Entity<WorktreeStore>,
176 project: WeakEntity<Project>,
177 registry: Entity<ContextServerDescriptorRegistry>,
178 update_servers_task: Option<Task<Result<()>>>,
179 context_server_factory: Option<ContextServerFactory>,
180 needs_server_update: bool,
181 _subscriptions: Vec<Subscription>,
182}
183
184pub enum Event {
185 ServerStatusChanged {
186 server_id: ContextServerId,
187 status: ContextServerStatus,
188 },
189}
190
191impl EventEmitter<Event> for ContextServerStore {}
192
193impl ContextServerStore {
194 pub fn new(
195 worktree_store: Entity<WorktreeStore>,
196 weak_project: WeakEntity<Project>,
197 cx: &mut Context<Self>,
198 ) -> Self {
199 Self::new_internal(
200 true,
201 None,
202 ContextServerDescriptorRegistry::default_global(cx),
203 worktree_store,
204 weak_project,
205 cx,
206 )
207 }
208
209 /// Returns all configured context server ids, excluding the ones that are disabled
210 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
211 self.context_server_settings
212 .iter()
213 .filter(|(_, settings)| settings.enabled())
214 .map(|(id, _)| ContextServerId(id.clone()))
215 .collect()
216 }
217
218 #[cfg(any(test, feature = "test-support"))]
219 pub fn test(
220 registry: Entity<ContextServerDescriptorRegistry>,
221 worktree_store: Entity<WorktreeStore>,
222 weak_project: WeakEntity<Project>,
223 cx: &mut Context<Self>,
224 ) -> Self {
225 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
226 }
227
228 #[cfg(any(test, feature = "test-support"))]
229 pub fn test_maintain_server_loop(
230 context_server_factory: Option<ContextServerFactory>,
231 registry: Entity<ContextServerDescriptorRegistry>,
232 worktree_store: Entity<WorktreeStore>,
233 weak_project: WeakEntity<Project>,
234 cx: &mut Context<Self>,
235 ) -> Self {
236 Self::new_internal(
237 true,
238 context_server_factory,
239 registry,
240 worktree_store,
241 weak_project,
242 cx,
243 )
244 }
245
246 fn new_internal(
247 maintain_server_loop: bool,
248 context_server_factory: Option<ContextServerFactory>,
249 registry: Entity<ContextServerDescriptorRegistry>,
250 worktree_store: Entity<WorktreeStore>,
251 weak_project: WeakEntity<Project>,
252 cx: &mut Context<Self>,
253 ) -> Self {
254 let subscriptions = if maintain_server_loop {
255 vec![
256 cx.observe(®istry, |this, _registry, cx| {
257 this.available_context_servers_changed(cx);
258 }),
259 cx.observe_global::<SettingsStore>(|this, cx| {
260 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
261 if &this.context_server_settings == settings {
262 return;
263 }
264 this.context_server_settings = settings.clone();
265 this.available_context_servers_changed(cx);
266 }),
267 ]
268 } else {
269 Vec::new()
270 };
271
272 let mut this = Self {
273 _subscriptions: subscriptions,
274 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
275 .clone(),
276 worktree_store,
277 project: weak_project,
278 registry,
279 needs_server_update: false,
280 servers: HashMap::default(),
281 update_servers_task: None,
282 context_server_factory,
283 };
284 if maintain_server_loop {
285 this.available_context_servers_changed(cx);
286 }
287 this
288 }
289
290 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
291 self.servers.get(id).map(|state| state.server())
292 }
293
294 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
295 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
296 Some(server.clone())
297 } else {
298 None
299 }
300 }
301
302 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
303 self.servers.get(id).map(ContextServerStatus::from_state)
304 }
305
306 pub fn configuration_for_server(
307 &self,
308 id: &ContextServerId,
309 ) -> Option<Arc<ContextServerConfiguration>> {
310 self.servers.get(id).map(|state| state.configuration())
311 }
312
313 pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
314 self.servers
315 .keys()
316 .cloned()
317 .chain(
318 self.registry
319 .read(cx)
320 .context_server_descriptors()
321 .into_iter()
322 .map(|(id, _)| ContextServerId(id)),
323 )
324 .collect()
325 }
326
327 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
328 self.servers
329 .values()
330 .filter_map(|state| {
331 if let ContextServerState::Running { server, .. } = state {
332 Some(server.clone())
333 } else {
334 None
335 }
336 })
337 .collect()
338 }
339
340 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
341 cx.spawn(async move |this, cx| {
342 let this = this.upgrade().context("Context server store dropped")?;
343 let settings = this
344 .update(cx, |this, _| {
345 this.context_server_settings.get(&server.id().0).cloned()
346 })
347 .ok()
348 .flatten()
349 .context("Failed to get context server settings")?;
350
351 if !settings.enabled() {
352 return Ok(());
353 }
354
355 let (registry, worktree_store) = this.update(cx, |this, _| {
356 (this.registry.clone(), this.worktree_store.clone())
357 })?;
358 let configuration = ContextServerConfiguration::from_settings(
359 settings,
360 server.id(),
361 registry,
362 worktree_store,
363 cx,
364 )
365 .await
366 .context("Failed to create context server configuration")?;
367
368 this.update(cx, |this, cx| {
369 this.run_server(server, Arc::new(configuration), cx)
370 })
371 })
372 .detach_and_log_err(cx);
373 }
374
375 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
376 if matches!(
377 self.servers.get(id),
378 Some(ContextServerState::Stopped { .. })
379 ) {
380 return Ok(());
381 }
382
383 let state = self
384 .servers
385 .remove(id)
386 .context("Context server not found")?;
387
388 let server = state.server();
389 let configuration = state.configuration();
390 let mut result = Ok(());
391 if let ContextServerState::Running { server, .. } = &state {
392 result = server.stop();
393 }
394 drop(state);
395
396 self.update_server_state(
397 id.clone(),
398 ContextServerState::Stopped {
399 configuration,
400 server,
401 },
402 cx,
403 );
404
405 result
406 }
407
408 fn run_server(
409 &mut self,
410 server: Arc<ContextServer>,
411 configuration: Arc<ContextServerConfiguration>,
412 cx: &mut Context<Self>,
413 ) {
414 let id = server.id();
415 if matches!(
416 self.servers.get(&id),
417 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
418 ) {
419 self.stop_server(&id, cx).log_err();
420 }
421 let task = cx.spawn({
422 let id = server.id();
423 let server = server.clone();
424 let configuration = configuration.clone();
425
426 async move |this, cx| {
427 match server.clone().start(cx).await {
428 Ok(_) => {
429 debug_assert!(server.client().is_some());
430
431 this.update(cx, |this, cx| {
432 this.update_server_state(
433 id.clone(),
434 ContextServerState::Running {
435 server,
436 configuration,
437 },
438 cx,
439 )
440 })
441 .log_err()
442 }
443 Err(err) => {
444 log::error!("{} context server failed to start: {}", id, err);
445 this.update(cx, |this, cx| {
446 this.update_server_state(
447 id.clone(),
448 ContextServerState::Error {
449 configuration,
450 server,
451 error: err.to_string().into(),
452 },
453 cx,
454 )
455 })
456 .log_err()
457 }
458 };
459 }
460 });
461
462 self.update_server_state(
463 id.clone(),
464 ContextServerState::Starting {
465 configuration,
466 _task: task,
467 server,
468 },
469 cx,
470 );
471 }
472
473 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
474 let state = self
475 .servers
476 .remove(id)
477 .context("Context server not found")?;
478 drop(state);
479 cx.emit(Event::ServerStatusChanged {
480 server_id: id.clone(),
481 status: ContextServerStatus::Stopped,
482 });
483 Ok(())
484 }
485
486 fn create_context_server(
487 &self,
488 id: ContextServerId,
489 configuration: Arc<ContextServerConfiguration>,
490 cx: &mut Context<Self>,
491 ) -> Result<Arc<ContextServer>> {
492 // Get global timeout from settings
493 let global_timeout = ProjectSettings::get_global(cx).context_server_timeout;
494
495 if let Some(factory) = self.context_server_factory.as_ref() {
496 return Ok(factory(id, configuration));
497 }
498
499 match configuration.as_ref() {
500 ContextServerConfiguration::Http {
501 url,
502 headers,
503 timeout,
504 } => {
505 // Apply timeout precedence for HTTP servers: per-server > global
506 let resolved_timeout = timeout.unwrap_or(global_timeout);
507
508 Ok(Arc::new(ContextServer::http(
509 id,
510 url,
511 headers.clone(),
512 cx.http_client(),
513 cx.background_executor().clone(),
514 Some(Duration::from_millis(resolved_timeout)),
515 )?))
516 }
517 _ => {
518 let root_path = self
519 .project
520 .read_with(cx, |project, cx| project.active_project_directory(cx))
521 .ok()
522 .flatten()
523 .or_else(|| {
524 self.worktree_store.read_with(cx, |store, cx| {
525 store.visible_worktrees(cx).fold(None, |acc, item| {
526 if acc.is_none() {
527 item.read(cx).root_dir()
528 } else {
529 acc
530 }
531 })
532 })
533 });
534
535 // Apply timeout precedence for stdio servers: per-server > global
536 let mut command_with_timeout = configuration.command().unwrap().clone();
537 if command_with_timeout.timeout.is_none() {
538 command_with_timeout.timeout = Some(global_timeout);
539 }
540
541 Ok(Arc::new(ContextServer::stdio(
542 id,
543 command_with_timeout,
544 root_path,
545 )))
546 }
547 }
548 }
549
550 fn resolve_context_server_settings<'a>(
551 worktree_store: &'a Entity<WorktreeStore>,
552 cx: &'a App,
553 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
554 let location = worktree_store
555 .read(cx)
556 .visible_worktrees(cx)
557 .next()
558 .map(|worktree| settings::SettingsLocation {
559 worktree_id: worktree.read(cx).id(),
560 path: RelPath::empty(),
561 });
562 &ProjectSettings::get(location, cx).context_servers
563 }
564
565 fn update_server_state(
566 &mut self,
567 id: ContextServerId,
568 state: ContextServerState,
569 cx: &mut Context<Self>,
570 ) {
571 let status = ContextServerStatus::from_state(&state);
572 self.servers.insert(id.clone(), state);
573 cx.emit(Event::ServerStatusChanged {
574 server_id: id,
575 status,
576 });
577 }
578
579 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
580 if self.update_servers_task.is_some() {
581 self.needs_server_update = true;
582 } else {
583 self.needs_server_update = false;
584 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
585 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
586 log::error!("Error maintaining context servers: {}", err);
587 }
588
589 this.update(cx, |this, cx| {
590 this.update_servers_task.take();
591 if this.needs_server_update {
592 this.available_context_servers_changed(cx);
593 }
594 })?;
595
596 Ok(())
597 }));
598 }
599 }
600
601 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
602 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
603 (
604 this.context_server_settings.clone(),
605 this.registry.clone(),
606 this.worktree_store.clone(),
607 )
608 })?;
609
610 for (id, _) in
611 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
612 {
613 configured_servers
614 .entry(id)
615 .or_insert(ContextServerSettings::default_extension());
616 }
617
618 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
619 configured_servers
620 .into_iter()
621 .partition(|(_, settings)| settings.enabled());
622
623 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
624 let id = ContextServerId(id);
625 ContextServerConfiguration::from_settings(
626 settings,
627 id.clone(),
628 registry.clone(),
629 worktree_store.clone(),
630 cx,
631 )
632 .map(|config| (id, config))
633 }))
634 .await
635 .into_iter()
636 .filter_map(|(id, config)| config.map(|config| (id, config)))
637 .collect::<HashMap<_, _>>();
638
639 let mut servers_to_start = Vec::new();
640 let mut servers_to_remove = HashSet::default();
641 let mut servers_to_stop = HashSet::default();
642
643 this.update(cx, |this, cx| {
644 for server_id in this.servers.keys() {
645 // All servers that are not in desired_servers should be removed from the store.
646 // This can happen if the user removed a server from the context server settings.
647 if !configured_servers.contains_key(server_id) {
648 if disabled_servers.contains_key(&server_id.0) {
649 servers_to_stop.insert(server_id.clone());
650 } else {
651 servers_to_remove.insert(server_id.clone());
652 }
653 }
654 }
655
656 for (id, config) in configured_servers {
657 let state = this.servers.get(&id);
658 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
659 let existing_config = state.as_ref().map(|state| state.configuration());
660 if existing_config.as_deref() != Some(&config) || is_stopped {
661 let config = Arc::new(config);
662 let server = this.create_context_server(id.clone(), config.clone(), cx)?;
663 servers_to_start.push((server, config));
664 if this.servers.contains_key(&id) {
665 servers_to_stop.insert(id);
666 }
667 }
668 }
669
670 anyhow::Ok(())
671 })??;
672
673 this.update(cx, |this, cx| {
674 for id in servers_to_stop {
675 this.stop_server(&id, cx)?;
676 }
677 for id in servers_to_remove {
678 this.remove_server(&id, cx)?;
679 }
680 for (server, config) in servers_to_start {
681 this.run_server(server, config, cx);
682 }
683 anyhow::Ok(())
684 })?
685 }
686}
687
688#[cfg(test)]
689mod tests {
690 use super::*;
691 use crate::{
692 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
693 project_settings::ProjectSettings,
694 };
695 use context_server::test::create_fake_transport;
696 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
697 use http_client::{FakeHttpClient, Response};
698 use serde_json::json;
699 use std::{cell::RefCell, path::PathBuf, rc::Rc};
700 use util::path;
701
702 #[gpui::test]
703 async fn test_context_server_status(cx: &mut TestAppContext) {
704 const SERVER_1_ID: &str = "mcp-1";
705 const SERVER_2_ID: &str = "mcp-2";
706
707 let (_fs, project) = setup_context_server_test(
708 cx,
709 json!({"code.rs": ""}),
710 vec![
711 (SERVER_1_ID.into(), dummy_server_settings()),
712 (SERVER_2_ID.into(), dummy_server_settings()),
713 ],
714 )
715 .await;
716
717 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
718 let store = cx.new(|cx| {
719 ContextServerStore::test(
720 registry.clone(),
721 project.read(cx).worktree_store(),
722 project.downgrade(),
723 cx,
724 )
725 });
726
727 let server_1_id = ContextServerId(SERVER_1_ID.into());
728 let server_2_id = ContextServerId(SERVER_2_ID.into());
729
730 let server_1 = Arc::new(ContextServer::new(
731 server_1_id.clone(),
732 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
733 ));
734 let server_2 = Arc::new(ContextServer::new(
735 server_2_id.clone(),
736 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
737 ));
738
739 store.update(cx, |store, cx| store.start_server(server_1, cx));
740
741 cx.run_until_parked();
742
743 cx.update(|cx| {
744 assert_eq!(
745 store.read(cx).status_for_server(&server_1_id),
746 Some(ContextServerStatus::Running)
747 );
748 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
749 });
750
751 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
752
753 cx.run_until_parked();
754
755 cx.update(|cx| {
756 assert_eq!(
757 store.read(cx).status_for_server(&server_1_id),
758 Some(ContextServerStatus::Running)
759 );
760 assert_eq!(
761 store.read(cx).status_for_server(&server_2_id),
762 Some(ContextServerStatus::Running)
763 );
764 });
765
766 store
767 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
768 .unwrap();
769
770 cx.update(|cx| {
771 assert_eq!(
772 store.read(cx).status_for_server(&server_1_id),
773 Some(ContextServerStatus::Running)
774 );
775 assert_eq!(
776 store.read(cx).status_for_server(&server_2_id),
777 Some(ContextServerStatus::Stopped)
778 );
779 });
780 }
781
782 #[gpui::test]
783 async fn test_context_server_status_events(cx: &mut TestAppContext) {
784 const SERVER_1_ID: &str = "mcp-1";
785 const SERVER_2_ID: &str = "mcp-2";
786
787 let (_fs, project) = setup_context_server_test(
788 cx,
789 json!({"code.rs": ""}),
790 vec![
791 (SERVER_1_ID.into(), dummy_server_settings()),
792 (SERVER_2_ID.into(), dummy_server_settings()),
793 ],
794 )
795 .await;
796
797 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
798 let store = cx.new(|cx| {
799 ContextServerStore::test(
800 registry.clone(),
801 project.read(cx).worktree_store(),
802 project.downgrade(),
803 cx,
804 )
805 });
806
807 let server_1_id = ContextServerId(SERVER_1_ID.into());
808 let server_2_id = ContextServerId(SERVER_2_ID.into());
809
810 let server_1 = Arc::new(ContextServer::new(
811 server_1_id.clone(),
812 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
813 ));
814 let server_2 = Arc::new(ContextServer::new(
815 server_2_id.clone(),
816 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
817 ));
818
819 let _server_events = assert_server_events(
820 &store,
821 vec![
822 (server_1_id.clone(), ContextServerStatus::Starting),
823 (server_1_id, ContextServerStatus::Running),
824 (server_2_id.clone(), ContextServerStatus::Starting),
825 (server_2_id.clone(), ContextServerStatus::Running),
826 (server_2_id.clone(), ContextServerStatus::Stopped),
827 ],
828 cx,
829 );
830
831 store.update(cx, |store, cx| store.start_server(server_1, cx));
832
833 cx.run_until_parked();
834
835 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
836
837 cx.run_until_parked();
838
839 store
840 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
841 .unwrap();
842 }
843
844 #[gpui::test(iterations = 25)]
845 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
846 const SERVER_1_ID: &str = "mcp-1";
847
848 let (_fs, project) = setup_context_server_test(
849 cx,
850 json!({"code.rs": ""}),
851 vec![(SERVER_1_ID.into(), dummy_server_settings())],
852 )
853 .await;
854
855 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
856 let store = cx.new(|cx| {
857 ContextServerStore::test(
858 registry.clone(),
859 project.read(cx).worktree_store(),
860 project.downgrade(),
861 cx,
862 )
863 });
864
865 let server_id = ContextServerId(SERVER_1_ID.into());
866
867 let server_with_same_id_1 = Arc::new(ContextServer::new(
868 server_id.clone(),
869 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
870 ));
871 let server_with_same_id_2 = Arc::new(ContextServer::new(
872 server_id.clone(),
873 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
874 ));
875
876 // If we start another server with the same id, we should report that we stopped the previous one
877 let _server_events = assert_server_events(
878 &store,
879 vec![
880 (server_id.clone(), ContextServerStatus::Starting),
881 (server_id.clone(), ContextServerStatus::Stopped),
882 (server_id.clone(), ContextServerStatus::Starting),
883 (server_id.clone(), ContextServerStatus::Running),
884 ],
885 cx,
886 );
887
888 store.update(cx, |store, cx| {
889 store.start_server(server_with_same_id_1.clone(), cx)
890 });
891 store.update(cx, |store, cx| {
892 store.start_server(server_with_same_id_2.clone(), cx)
893 });
894
895 cx.run_until_parked();
896
897 cx.update(|cx| {
898 assert_eq!(
899 store.read(cx).status_for_server(&server_id),
900 Some(ContextServerStatus::Running)
901 );
902 });
903 }
904
905 #[gpui::test]
906 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
907 const SERVER_1_ID: &str = "mcp-1";
908 const SERVER_2_ID: &str = "mcp-2";
909
910 let server_1_id = ContextServerId(SERVER_1_ID.into());
911 let server_2_id = ContextServerId(SERVER_2_ID.into());
912
913 let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
914
915 let (_fs, project) = setup_context_server_test(
916 cx,
917 json!({"code.rs": ""}),
918 vec![(
919 SERVER_1_ID.into(),
920 ContextServerSettings::Extension {
921 enabled: true,
922 settings: json!({
923 "somevalue": true
924 }),
925 },
926 )],
927 )
928 .await;
929
930 let executor = cx.executor();
931 let registry = cx.new(|cx| {
932 let mut registry = ContextServerDescriptorRegistry::new();
933 registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
934 registry
935 });
936 let store = cx.new(|cx| {
937 ContextServerStore::test_maintain_server_loop(
938 Some(Box::new(move |id, _| {
939 Arc::new(ContextServer::new(
940 id.clone(),
941 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
942 ))
943 })),
944 registry.clone(),
945 project.read(cx).worktree_store(),
946 project.downgrade(),
947 cx,
948 )
949 });
950
951 // Ensure that mcp-1 starts up
952 {
953 let _server_events = assert_server_events(
954 &store,
955 vec![
956 (server_1_id.clone(), ContextServerStatus::Starting),
957 (server_1_id.clone(), ContextServerStatus::Running),
958 ],
959 cx,
960 );
961 cx.run_until_parked();
962 }
963
964 // Ensure that mcp-1 is restarted when the configuration was changed
965 {
966 let _server_events = assert_server_events(
967 &store,
968 vec![
969 (server_1_id.clone(), ContextServerStatus::Stopped),
970 (server_1_id.clone(), ContextServerStatus::Starting),
971 (server_1_id.clone(), ContextServerStatus::Running),
972 ],
973 cx,
974 );
975 set_context_server_configuration(
976 vec![(
977 server_1_id.0.clone(),
978 settings::ContextServerSettingsContent::Extension {
979 enabled: true,
980 settings: json!({
981 "somevalue": false
982 }),
983 },
984 )],
985 cx,
986 );
987
988 cx.run_until_parked();
989 }
990
991 // Ensure that mcp-1 is not restarted when the configuration was not changed
992 {
993 let _server_events = assert_server_events(&store, vec![], cx);
994 set_context_server_configuration(
995 vec![(
996 server_1_id.0.clone(),
997 settings::ContextServerSettingsContent::Extension {
998 enabled: true,
999 settings: json!({
1000 "somevalue": false
1001 }),
1002 },
1003 )],
1004 cx,
1005 );
1006
1007 cx.run_until_parked();
1008 }
1009
1010 // Ensure that mcp-2 is started once it is added to the settings
1011 {
1012 let _server_events = assert_server_events(
1013 &store,
1014 vec![
1015 (server_2_id.clone(), ContextServerStatus::Starting),
1016 (server_2_id.clone(), ContextServerStatus::Running),
1017 ],
1018 cx,
1019 );
1020 set_context_server_configuration(
1021 vec![
1022 (
1023 server_1_id.0.clone(),
1024 settings::ContextServerSettingsContent::Extension {
1025 enabled: true,
1026 settings: json!({
1027 "somevalue": false
1028 }),
1029 },
1030 ),
1031 (
1032 server_2_id.0.clone(),
1033 settings::ContextServerSettingsContent::Stdio {
1034 enabled: true,
1035 command: ContextServerCommand {
1036 path: "somebinary".into(),
1037 args: vec!["arg".to_string()],
1038 env: None,
1039 timeout: None,
1040 },
1041 },
1042 ),
1043 ],
1044 cx,
1045 );
1046
1047 cx.run_until_parked();
1048 }
1049
1050 // Ensure that mcp-2 is restarted once the args have changed
1051 {
1052 let _server_events = assert_server_events(
1053 &store,
1054 vec![
1055 (server_2_id.clone(), ContextServerStatus::Stopped),
1056 (server_2_id.clone(), ContextServerStatus::Starting),
1057 (server_2_id.clone(), ContextServerStatus::Running),
1058 ],
1059 cx,
1060 );
1061 set_context_server_configuration(
1062 vec![
1063 (
1064 server_1_id.0.clone(),
1065 settings::ContextServerSettingsContent::Extension {
1066 enabled: true,
1067 settings: json!({
1068 "somevalue": false
1069 }),
1070 },
1071 ),
1072 (
1073 server_2_id.0.clone(),
1074 settings::ContextServerSettingsContent::Stdio {
1075 enabled: true,
1076 command: ContextServerCommand {
1077 path: "somebinary".into(),
1078 args: vec!["anotherArg".to_string()],
1079 env: None,
1080 timeout: None,
1081 },
1082 },
1083 ),
1084 ],
1085 cx,
1086 );
1087
1088 cx.run_until_parked();
1089 }
1090
1091 // Ensure that mcp-2 is removed once it is removed from the settings
1092 {
1093 let _server_events = assert_server_events(
1094 &store,
1095 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1096 cx,
1097 );
1098 set_context_server_configuration(
1099 vec![(
1100 server_1_id.0.clone(),
1101 settings::ContextServerSettingsContent::Extension {
1102 enabled: true,
1103 settings: json!({
1104 "somevalue": false
1105 }),
1106 },
1107 )],
1108 cx,
1109 );
1110
1111 cx.run_until_parked();
1112
1113 cx.update(|cx| {
1114 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1115 });
1116 }
1117
1118 // Ensure that nothing happens if the settings do not change
1119 {
1120 let _server_events = assert_server_events(&store, vec![], cx);
1121 set_context_server_configuration(
1122 vec![(
1123 server_1_id.0.clone(),
1124 settings::ContextServerSettingsContent::Extension {
1125 enabled: true,
1126 settings: json!({
1127 "somevalue": false
1128 }),
1129 },
1130 )],
1131 cx,
1132 );
1133
1134 cx.run_until_parked();
1135
1136 cx.update(|cx| {
1137 assert_eq!(
1138 store.read(cx).status_for_server(&server_1_id),
1139 Some(ContextServerStatus::Running)
1140 );
1141 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1142 });
1143 }
1144 }
1145
1146 #[gpui::test]
1147 async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1148 const SERVER_1_ID: &str = "mcp-1";
1149
1150 let server_1_id = ContextServerId(SERVER_1_ID.into());
1151
1152 let (_fs, project) = setup_context_server_test(
1153 cx,
1154 json!({"code.rs": ""}),
1155 vec![(
1156 SERVER_1_ID.into(),
1157 ContextServerSettings::Stdio {
1158 enabled: true,
1159 command: ContextServerCommand {
1160 path: "somebinary".into(),
1161 args: vec!["arg".to_string()],
1162 env: None,
1163 timeout: None,
1164 },
1165 },
1166 )],
1167 )
1168 .await;
1169
1170 let executor = cx.executor();
1171 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1172 let store = cx.new(|cx| {
1173 ContextServerStore::test_maintain_server_loop(
1174 Some(Box::new(move |id, _| {
1175 Arc::new(ContextServer::new(
1176 id.clone(),
1177 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1178 ))
1179 })),
1180 registry.clone(),
1181 project.read(cx).worktree_store(),
1182 project.downgrade(),
1183 cx,
1184 )
1185 });
1186
1187 // Ensure that mcp-1 starts up
1188 {
1189 let _server_events = assert_server_events(
1190 &store,
1191 vec![
1192 (server_1_id.clone(), ContextServerStatus::Starting),
1193 (server_1_id.clone(), ContextServerStatus::Running),
1194 ],
1195 cx,
1196 );
1197 cx.run_until_parked();
1198 }
1199
1200 // Ensure that mcp-1 is stopped once it is disabled.
1201 {
1202 let _server_events = assert_server_events(
1203 &store,
1204 vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1205 cx,
1206 );
1207 set_context_server_configuration(
1208 vec![(
1209 server_1_id.0.clone(),
1210 settings::ContextServerSettingsContent::Stdio {
1211 enabled: false,
1212 command: ContextServerCommand {
1213 path: "somebinary".into(),
1214 args: vec!["arg".to_string()],
1215 env: None,
1216 timeout: None,
1217 },
1218 },
1219 )],
1220 cx,
1221 );
1222
1223 cx.run_until_parked();
1224 }
1225
1226 // Ensure that mcp-1 is started once it is enabled again.
1227 {
1228 let _server_events = assert_server_events(
1229 &store,
1230 vec![
1231 (server_1_id.clone(), ContextServerStatus::Starting),
1232 (server_1_id.clone(), ContextServerStatus::Running),
1233 ],
1234 cx,
1235 );
1236 set_context_server_configuration(
1237 vec![(
1238 server_1_id.0.clone(),
1239 settings::ContextServerSettingsContent::Stdio {
1240 enabled: true,
1241 command: ContextServerCommand {
1242 path: "somebinary".into(),
1243 args: vec!["arg".to_string()],
1244 timeout: None,
1245 env: None,
1246 },
1247 },
1248 )],
1249 cx,
1250 );
1251
1252 cx.run_until_parked();
1253 }
1254 }
1255
1256 fn set_context_server_configuration(
1257 context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1258 cx: &mut TestAppContext,
1259 ) {
1260 cx.update(|cx| {
1261 SettingsStore::update_global(cx, |store, cx| {
1262 store.update_user_settings(cx, |content| {
1263 content.project.context_servers.clear();
1264 for (id, config) in context_servers {
1265 content.project.context_servers.insert(id, config);
1266 }
1267 });
1268 })
1269 });
1270 }
1271
1272 #[gpui::test]
1273 async fn test_remote_context_server(cx: &mut TestAppContext) {
1274 const SERVER_ID: &str = "remote-server";
1275 let server_id = ContextServerId(SERVER_ID.into());
1276 let server_url = "http://example.com/api";
1277
1278 let (_fs, project) = setup_context_server_test(
1279 cx,
1280 json!({ "code.rs": "" }),
1281 vec![(
1282 SERVER_ID.into(),
1283 ContextServerSettings::Http {
1284 enabled: true,
1285 url: server_url.to_string(),
1286 headers: Default::default(),
1287 timeout: None,
1288 },
1289 )],
1290 )
1291 .await;
1292
1293 let client = FakeHttpClient::create(|_| async move {
1294 use http_client::AsyncBody;
1295
1296 let response = Response::builder()
1297 .status(200)
1298 .header("Content-Type", "application/json")
1299 .body(AsyncBody::from(
1300 serde_json::to_string(&json!({
1301 "jsonrpc": "2.0",
1302 "id": 0,
1303 "result": {
1304 "protocolVersion": "2024-11-05",
1305 "capabilities": {},
1306 "serverInfo": {
1307 "name": "test-server",
1308 "version": "1.0.0"
1309 }
1310 }
1311 }))
1312 .unwrap(),
1313 ))
1314 .unwrap();
1315 Ok(response)
1316 });
1317 cx.update(|cx| cx.set_http_client(client));
1318 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1319 let store = cx.new(|cx| {
1320 ContextServerStore::test_maintain_server_loop(
1321 None,
1322 registry.clone(),
1323 project.read(cx).worktree_store(),
1324 project.downgrade(),
1325 cx,
1326 )
1327 });
1328
1329 let _server_events = assert_server_events(
1330 &store,
1331 vec![
1332 (server_id.clone(), ContextServerStatus::Starting),
1333 (server_id.clone(), ContextServerStatus::Running),
1334 ],
1335 cx,
1336 );
1337 cx.run_until_parked();
1338 }
1339
    /// Guard returned by `assert_server_events`.
    ///
    /// While alive it keeps the store subscription registered; its `Drop` impl compares
    /// the number of events received with the number expected.
    struct ServerEvents {
        // Incremented by the subscription callback after each event passes its checks.
        received_event_count: Rc<RefCell<usize>>,
        // Length of the expected-event list given to `assert_server_events`.
        expected_event_count: usize,
        // Held only so the subscription stays active for the guard's lifetime.
        _subscription: Subscription,
    }
1345
1346 impl Drop for ServerEvents {
1347 fn drop(&mut self) {
1348 let actual_event_count = *self.received_event_count.borrow();
1349 assert_eq!(
1350 actual_event_count, self.expected_event_count,
1351 "
1352 Expected to receive {} context server store events, but received {} events",
1353 self.expected_event_count, actual_event_count
1354 );
1355 }
1356 }
1357
1358 fn dummy_server_settings() -> ContextServerSettings {
1359 ContextServerSettings::Stdio {
1360 enabled: true,
1361 command: ContextServerCommand {
1362 path: "somebinary".into(),
1363 args: vec!["arg".to_string()],
1364 env: None,
1365 timeout: None,
1366 },
1367 }
1368 }
1369
1370 fn assert_server_events(
1371 store: &Entity<ContextServerStore>,
1372 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1373 cx: &mut TestAppContext,
1374 ) -> ServerEvents {
1375 cx.update(|cx| {
1376 let mut ix = 0;
1377 let received_event_count = Rc::new(RefCell::new(0));
1378 let expected_event_count = expected_events.len();
1379 let subscription = cx.subscribe(store, {
1380 let received_event_count = received_event_count.clone();
1381 move |_, event, _| match event {
1382 Event::ServerStatusChanged {
1383 server_id: actual_server_id,
1384 status: actual_status,
1385 } => {
1386 let (expected_server_id, expected_status) = &expected_events[ix];
1387
1388 assert_eq!(
1389 actual_server_id, expected_server_id,
1390 "Expected different server id at index {}",
1391 ix
1392 );
1393 assert_eq!(
1394 actual_status, expected_status,
1395 "Expected different status at index {}",
1396 ix
1397 );
1398 ix += 1;
1399 *received_event_count.borrow_mut() += 1;
1400 }
1401 }
1402 });
1403 ServerEvents {
1404 expected_event_count,
1405 received_event_count,
1406 _subscription: subscription,
1407 }
1408 })
1409 }
1410
1411 async fn setup_context_server_test(
1412 cx: &mut TestAppContext,
1413 files: serde_json::Value,
1414 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1415 ) -> (Arc<FakeFs>, Entity<Project>) {
1416 cx.update(|cx| {
1417 let settings_store = SettingsStore::test(cx);
1418 cx.set_global(settings_store);
1419 let mut settings = ProjectSettings::get_global(cx).clone();
1420 for (id, config) in context_server_configurations {
1421 settings.context_servers.insert(id, config);
1422 }
1423 ProjectSettings::override_global(settings, cx);
1424 });
1425
1426 let fs = FakeFs::new(cx.executor());
1427 fs.insert_tree(path!("/test"), files).await;
1428 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1429
1430 (fs, project)
1431 }
1432
    /// Test `ContextServerDescriptor` that returns a fixed command pointing at `path`.
    struct FakeContextServerDescriptor {
        // Executable path reported by the descriptor's `command`.
        path: PathBuf,
    }
1436
1437 impl FakeContextServerDescriptor {
1438 fn new(path: impl Into<PathBuf>) -> Self {
1439 Self { path: path.into() }
1440 }
1441 }
1442
1443 impl ContextServerDescriptor for FakeContextServerDescriptor {
1444 fn command(
1445 &self,
1446 _worktree_store: Entity<WorktreeStore>,
1447 _cx: &AsyncApp,
1448 ) -> Task<Result<ContextServerCommand>> {
1449 Task::ready(Ok(ContextServerCommand {
1450 path: self.path.clone(),
1451 args: vec!["arg1".to_string(), "arg2".to_string()],
1452 env: None,
1453 timeout: None,
1454 }))
1455 }
1456
1457 fn configuration(
1458 &self,
1459 _worktree_store: Entity<WorktreeStore>,
1460 _cx: &AsyncApp,
1461 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1462 Task::ready(Ok(None))
1463 }
1464 }
1465}