1pub mod extension;
2pub mod registry;
3
4use std::sync::Arc;
5use std::time::Duration;
6
7use anyhow::{Context as _, Result};
8use collections::{HashMap, HashSet};
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use futures::{FutureExt as _, future::join_all};
11use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
12use registry::ContextServerDescriptorRegistry;
13use settings::{Settings as _, SettingsStore};
14use util::{ResultExt as _, rel_path::RelPath};
15
16use crate::{
17 Project,
18 project_settings::{ContextServerSettings, ProjectSettings},
19 worktree_store::WorktreeStore,
20};
21
/// Upper bound, in seconds, applied to every context server request timeout
/// (per-server or global) so that extremely large configured values cannot tie
/// up resources indefinitely.
const MAX_TIMEOUT_SECS: u64 = 10 * 60;
25
26pub fn init(cx: &mut App) {
27 extension::init(cx);
28}
29
30actions!(
31 context_server,
32 [
33 /// Restarts the context server.
34 Restart
35 ]
36);
37
/// Externally visible lifecycle status of a context server, as returned by
/// [`ContextServerStore::status_for_server`] and carried by
/// [`Event::ServerStatusChanged`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's start task is in flight.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// The server failed to start; holds the error message.
    Error(Arc<str>),
}
45
46impl ContextServerStatus {
47 fn from_state(state: &ContextServerState) -> Self {
48 match state {
49 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
50 ContextServerState::Running { .. } => ContextServerStatus::Running,
51 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
52 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
53 }
54 }
55}
56
57enum ContextServerState {
58 Starting {
59 server: Arc<ContextServer>,
60 configuration: Arc<ContextServerConfiguration>,
61 _task: Task<()>,
62 },
63 Running {
64 server: Arc<ContextServer>,
65 configuration: Arc<ContextServerConfiguration>,
66 },
67 Stopped {
68 server: Arc<ContextServer>,
69 configuration: Arc<ContextServerConfiguration>,
70 },
71 Error {
72 server: Arc<ContextServer>,
73 configuration: Arc<ContextServerConfiguration>,
74 error: Arc<str>,
75 },
76}
77
78impl ContextServerState {
79 pub fn server(&self) -> Arc<ContextServer> {
80 match self {
81 ContextServerState::Starting { server, .. } => server.clone(),
82 ContextServerState::Running { server, .. } => server.clone(),
83 ContextServerState::Stopped { server, .. } => server.clone(),
84 ContextServerState::Error { server, .. } => server.clone(),
85 }
86 }
87
88 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
89 match self {
90 ContextServerState::Starting { configuration, .. } => configuration.clone(),
91 ContextServerState::Running { configuration, .. } => configuration.clone(),
92 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
93 ContextServerState::Error { configuration, .. } => configuration.clone(),
94 }
95 }
96}
97
98#[derive(Debug, PartialEq, Eq)]
99pub enum ContextServerConfiguration {
100 Custom {
101 command: ContextServerCommand,
102 },
103 Extension {
104 command: ContextServerCommand,
105 settings: serde_json::Value,
106 },
107 Http {
108 url: url::Url,
109 headers: HashMap<String, String>,
110 timeout: Option<u64>,
111 },
112}
113
114impl ContextServerConfiguration {
115 pub fn command(&self) -> Option<&ContextServerCommand> {
116 match self {
117 ContextServerConfiguration::Custom { command } => Some(command),
118 ContextServerConfiguration::Extension { command, .. } => Some(command),
119 ContextServerConfiguration::Http { .. } => None,
120 }
121 }
122
123 pub async fn from_settings(
124 settings: ContextServerSettings,
125 id: ContextServerId,
126 registry: Entity<ContextServerDescriptorRegistry>,
127 worktree_store: Entity<WorktreeStore>,
128 cx: &AsyncApp,
129 ) -> Option<Self> {
130 match settings {
131 ContextServerSettings::Stdio {
132 enabled: _,
133 command,
134 } => Some(ContextServerConfiguration::Custom { command }),
135 ContextServerSettings::Extension {
136 enabled: _,
137 settings,
138 } => {
139 let descriptor =
140 cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
141
142 match descriptor.command(worktree_store, cx).await {
143 Ok(command) => {
144 Some(ContextServerConfiguration::Extension { command, settings })
145 }
146 Err(e) => {
147 log::error!(
148 "Failed to create context server configuration from settings: {e:#}"
149 );
150 None
151 }
152 }
153 }
154 ContextServerSettings::Http {
155 enabled: _,
156 url,
157 headers: auth,
158 timeout,
159 } => {
160 let url = url::Url::parse(&url).log_err()?;
161 Some(ContextServerConfiguration::Http {
162 url,
163 headers: auth,
164 timeout,
165 })
166 }
167 }
168 }
169}
170
171pub type ContextServerFactory =
172 Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
173
174pub struct ContextServerStore {
175 context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
176 servers: HashMap<ContextServerId, ContextServerState>,
177 worktree_store: Entity<WorktreeStore>,
178 project: WeakEntity<Project>,
179 registry: Entity<ContextServerDescriptorRegistry>,
180 update_servers_task: Option<Task<Result<()>>>,
181 context_server_factory: Option<ContextServerFactory>,
182 needs_server_update: bool,
183 _subscriptions: Vec<Subscription>,
184}
185
186pub enum Event {
187 ServerStatusChanged {
188 server_id: ContextServerId,
189 status: ContextServerStatus,
190 },
191}
192
193impl EventEmitter<Event> for ContextServerStore {}
194
195impl ContextServerStore {
196 pub fn new(
197 worktree_store: Entity<WorktreeStore>,
198 weak_project: WeakEntity<Project>,
199 cx: &mut Context<Self>,
200 ) -> Self {
201 Self::new_internal(
202 true,
203 None,
204 ContextServerDescriptorRegistry::default_global(cx),
205 worktree_store,
206 weak_project,
207 cx,
208 )
209 }
210
211 /// Returns all configured context server ids, excluding the ones that are disabled
212 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
213 self.context_server_settings
214 .iter()
215 .filter(|(_, settings)| settings.enabled())
216 .map(|(id, _)| ContextServerId(id.clone()))
217 .collect()
218 }
219
220 #[cfg(any(test, feature = "test-support"))]
221 pub fn test(
222 registry: Entity<ContextServerDescriptorRegistry>,
223 worktree_store: Entity<WorktreeStore>,
224 weak_project: WeakEntity<Project>,
225 cx: &mut Context<Self>,
226 ) -> Self {
227 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
228 }
229
230 #[cfg(any(test, feature = "test-support"))]
231 pub fn test_maintain_server_loop(
232 context_server_factory: Option<ContextServerFactory>,
233 registry: Entity<ContextServerDescriptorRegistry>,
234 worktree_store: Entity<WorktreeStore>,
235 weak_project: WeakEntity<Project>,
236 cx: &mut Context<Self>,
237 ) -> Self {
238 Self::new_internal(
239 true,
240 context_server_factory,
241 registry,
242 worktree_store,
243 weak_project,
244 cx,
245 )
246 }
247
248 #[cfg(any(test, feature = "test-support"))]
249 pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
250 self.context_server_factory = Some(factory);
251 }
252
253 #[cfg(any(test, feature = "test-support"))]
254 pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
255 &self.registry
256 }
257
258 #[cfg(any(test, feature = "test-support"))]
259 pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
260 let configuration = Arc::new(ContextServerConfiguration::Custom {
261 command: ContextServerCommand {
262 path: "test".into(),
263 args: vec![],
264 env: None,
265 timeout: None,
266 },
267 });
268 self.run_server(server, configuration, cx);
269 }
270
271 fn new_internal(
272 maintain_server_loop: bool,
273 context_server_factory: Option<ContextServerFactory>,
274 registry: Entity<ContextServerDescriptorRegistry>,
275 worktree_store: Entity<WorktreeStore>,
276 weak_project: WeakEntity<Project>,
277 cx: &mut Context<Self>,
278 ) -> Self {
279 let subscriptions = if maintain_server_loop {
280 vec![
281 cx.observe(®istry, |this, _registry, cx| {
282 this.available_context_servers_changed(cx);
283 }),
284 cx.observe_global::<SettingsStore>(|this, cx| {
285 let settings =
286 &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
287 if &this.context_server_settings == settings {
288 return;
289 }
290 this.context_server_settings = settings.clone();
291 this.available_context_servers_changed(cx);
292 }),
293 ]
294 } else {
295 Vec::new()
296 };
297
298 let mut this = Self {
299 _subscriptions: subscriptions,
300 context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
301 .context_servers
302 .clone(),
303 worktree_store,
304 project: weak_project,
305 registry,
306 needs_server_update: false,
307 servers: HashMap::default(),
308 update_servers_task: None,
309 context_server_factory,
310 };
311 if maintain_server_loop {
312 this.available_context_servers_changed(cx);
313 }
314 this
315 }
316
317 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
318 self.servers.get(id).map(|state| state.server())
319 }
320
321 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
322 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
323 Some(server.clone())
324 } else {
325 None
326 }
327 }
328
329 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
330 self.servers.get(id).map(ContextServerStatus::from_state)
331 }
332
333 pub fn configuration_for_server(
334 &self,
335 id: &ContextServerId,
336 ) -> Option<Arc<ContextServerConfiguration>> {
337 self.servers.get(id).map(|state| state.configuration())
338 }
339
340 pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
341 self.servers
342 .keys()
343 .cloned()
344 .chain(
345 self.registry
346 .read(cx)
347 .context_server_descriptors()
348 .into_iter()
349 .map(|(id, _)| ContextServerId(id)),
350 )
351 .collect()
352 }
353
354 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
355 self.servers
356 .values()
357 .filter_map(|state| {
358 if let ContextServerState::Running { server, .. } = state {
359 Some(server.clone())
360 } else {
361 None
362 }
363 })
364 .collect()
365 }
366
367 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
368 cx.spawn(async move |this, cx| {
369 let this = this.upgrade().context("Context server store dropped")?;
370 let settings = this
371 .update(cx, |this, _| {
372 this.context_server_settings.get(&server.id().0).cloned()
373 })
374 .context("Failed to get context server settings")?;
375
376 if !settings.enabled() {
377 return anyhow::Ok(());
378 }
379
380 let (registry, worktree_store) = this.update(cx, |this, _| {
381 (this.registry.clone(), this.worktree_store.clone())
382 });
383 let configuration = ContextServerConfiguration::from_settings(
384 settings,
385 server.id(),
386 registry,
387 worktree_store,
388 cx,
389 )
390 .await
391 .context("Failed to create context server configuration")?;
392
393 this.update(cx, |this, cx| {
394 this.run_server(server, Arc::new(configuration), cx)
395 });
396 Ok(())
397 })
398 .detach_and_log_err(cx);
399 }
400
401 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
402 if matches!(
403 self.servers.get(id),
404 Some(ContextServerState::Stopped { .. })
405 ) {
406 return Ok(());
407 }
408
409 let state = self
410 .servers
411 .remove(id)
412 .context("Context server not found")?;
413
414 let server = state.server();
415 let configuration = state.configuration();
416 let mut result = Ok(());
417 if let ContextServerState::Running { server, .. } = &state {
418 result = server.stop();
419 }
420 drop(state);
421
422 self.update_server_state(
423 id.clone(),
424 ContextServerState::Stopped {
425 configuration,
426 server,
427 },
428 cx,
429 );
430
431 result
432 }
433
434 fn run_server(
435 &mut self,
436 server: Arc<ContextServer>,
437 configuration: Arc<ContextServerConfiguration>,
438 cx: &mut Context<Self>,
439 ) {
440 let id = server.id();
441 if matches!(
442 self.servers.get(&id),
443 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
444 ) {
445 self.stop_server(&id, cx).log_err();
446 }
447 let task = cx.spawn({
448 let id = server.id();
449 let server = server.clone();
450 let configuration = configuration.clone();
451
452 async move |this, cx| {
453 match server.clone().start(cx).await {
454 Ok(_) => {
455 debug_assert!(server.client().is_some());
456
457 this.update(cx, |this, cx| {
458 this.update_server_state(
459 id.clone(),
460 ContextServerState::Running {
461 server,
462 configuration,
463 },
464 cx,
465 )
466 })
467 .log_err()
468 }
469 Err(err) => {
470 log::error!("{} context server failed to start: {}", id, err);
471 this.update(cx, |this, cx| {
472 this.update_server_state(
473 id.clone(),
474 ContextServerState::Error {
475 configuration,
476 server,
477 error: err.to_string().into(),
478 },
479 cx,
480 )
481 })
482 .log_err()
483 }
484 };
485 }
486 });
487
488 self.update_server_state(
489 id.clone(),
490 ContextServerState::Starting {
491 configuration,
492 _task: task,
493 server,
494 },
495 cx,
496 );
497 }
498
499 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
500 let state = self
501 .servers
502 .remove(id)
503 .context("Context server not found")?;
504 drop(state);
505 cx.emit(Event::ServerStatusChanged {
506 server_id: id.clone(),
507 status: ContextServerStatus::Stopped,
508 });
509 Ok(())
510 }
511
512 fn create_context_server(
513 &self,
514 id: ContextServerId,
515 configuration: Arc<ContextServerConfiguration>,
516 cx: &mut Context<Self>,
517 ) -> Result<Arc<ContextServer>> {
518 let global_timeout =
519 Self::resolve_project_settings(&self.worktree_store, cx).context_server_timeout;
520
521 if let Some(factory) = self.context_server_factory.as_ref() {
522 return Ok(factory(id, configuration));
523 }
524
525 match configuration.as_ref() {
526 ContextServerConfiguration::Http {
527 url,
528 headers,
529 timeout,
530 } => Ok(Arc::new(ContextServer::http(
531 id,
532 url,
533 headers.clone(),
534 cx.http_client(),
535 cx.background_executor().clone(),
536 Some(Duration::from_secs(
537 timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
538 )),
539 )?)),
540 _ => {
541 let root_path = self
542 .project
543 .read_with(cx, |project, cx| project.active_project_directory(cx))
544 .ok()
545 .flatten()
546 .or_else(|| {
547 self.worktree_store.read_with(cx, |store, cx| {
548 store.visible_worktrees(cx).fold(None, |acc, item| {
549 if acc.is_none() {
550 item.read(cx).root_dir()
551 } else {
552 acc
553 }
554 })
555 })
556 });
557
558 let mut command = configuration
559 .command()
560 .context("Missing command configuration for stdio context server")?
561 .clone();
562 command.timeout = Some(
563 command
564 .timeout
565 .unwrap_or(global_timeout)
566 .min(MAX_TIMEOUT_SECS),
567 );
568
569 Ok(Arc::new(ContextServer::stdio(id, command, root_path)))
570 }
571 }
572 }
573
574 fn resolve_project_settings<'a>(
575 worktree_store: &'a Entity<WorktreeStore>,
576 cx: &'a App,
577 ) -> &'a ProjectSettings {
578 let location = worktree_store
579 .read(cx)
580 .visible_worktrees(cx)
581 .next()
582 .map(|worktree| settings::SettingsLocation {
583 worktree_id: worktree.read(cx).id(),
584 path: RelPath::empty(),
585 });
586 ProjectSettings::get(location, cx)
587 }
588
589 fn update_server_state(
590 &mut self,
591 id: ContextServerId,
592 state: ContextServerState,
593 cx: &mut Context<Self>,
594 ) {
595 let status = ContextServerStatus::from_state(&state);
596 self.servers.insert(id.clone(), state);
597 cx.emit(Event::ServerStatusChanged {
598 server_id: id,
599 status,
600 });
601 }
602
603 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
604 if self.update_servers_task.is_some() {
605 self.needs_server_update = true;
606 } else {
607 self.needs_server_update = false;
608 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
609 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
610 log::error!("Error maintaining context servers: {}", err);
611 }
612
613 this.update(cx, |this, cx| {
614 this.update_servers_task.take();
615 if this.needs_server_update {
616 this.available_context_servers_changed(cx);
617 }
618 })?;
619
620 Ok(())
621 }));
622 }
623 }
624
625 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
626 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
627 (
628 this.context_server_settings.clone(),
629 this.registry.clone(),
630 this.worktree_store.clone(),
631 )
632 })?;
633
634 for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
635 configured_servers
636 .entry(id)
637 .or_insert(ContextServerSettings::default_extension());
638 }
639
640 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
641 configured_servers
642 .into_iter()
643 .partition(|(_, settings)| settings.enabled());
644
645 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
646 let id = ContextServerId(id);
647 ContextServerConfiguration::from_settings(
648 settings,
649 id.clone(),
650 registry.clone(),
651 worktree_store.clone(),
652 cx,
653 )
654 .map(|config| (id, config))
655 }))
656 .await
657 .into_iter()
658 .filter_map(|(id, config)| config.map(|config| (id, config)))
659 .collect::<HashMap<_, _>>();
660
661 let mut servers_to_start = Vec::new();
662 let mut servers_to_remove = HashSet::default();
663 let mut servers_to_stop = HashSet::default();
664
665 this.update(cx, |this, cx| {
666 for server_id in this.servers.keys() {
667 // All servers that are not in desired_servers should be removed from the store.
668 // This can happen if the user removed a server from the context server settings.
669 if !configured_servers.contains_key(server_id) {
670 if disabled_servers.contains_key(&server_id.0) {
671 servers_to_stop.insert(server_id.clone());
672 } else {
673 servers_to_remove.insert(server_id.clone());
674 }
675 }
676 }
677
678 for (id, config) in configured_servers {
679 let state = this.servers.get(&id);
680 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
681 let existing_config = state.as_ref().map(|state| state.configuration());
682 if existing_config.as_deref() != Some(&config) || is_stopped {
683 let config = Arc::new(config);
684 let server = this.create_context_server(id.clone(), config.clone(), cx)?;
685 servers_to_start.push((server, config));
686 if this.servers.contains_key(&id) {
687 servers_to_stop.insert(id);
688 }
689 }
690 }
691
692 anyhow::Ok(())
693 })??;
694
695 this.update(cx, |this, cx| {
696 for id in servers_to_stop {
697 this.stop_server(&id, cx)?;
698 }
699 for id in servers_to_remove {
700 this.remove_server(&id, cx)?;
701 }
702 for (server, config) in servers_to_start {
703 this.run_server(server, config, cx);
704 }
705 anyhow::Ok(())
706 })?
707 }
708}
709
710#[cfg(test)]
711mod tests {
712 use super::*;
713 use crate::{
714 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
715 project_settings::ProjectSettings,
716 };
717 use context_server::test::create_fake_transport;
718 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
719 use http_client::{FakeHttpClient, Response};
720 use serde_json::json;
721 use std::{cell::RefCell, path::PathBuf, rc::Rc};
722 use util::path;
723
724 #[gpui::test]
725 async fn test_context_server_status(cx: &mut TestAppContext) {
726 const SERVER_1_ID: &str = "mcp-1";
727 const SERVER_2_ID: &str = "mcp-2";
728
729 let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await;
730
731 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
732 let store = cx.new(|cx| {
733 ContextServerStore::test(
734 registry.clone(),
735 project.read(cx).worktree_store(),
736 project.downgrade(),
737 cx,
738 )
739 });
740
741 let server_1_id = ContextServerId(SERVER_1_ID.into());
742 let server_2_id = ContextServerId(SERVER_2_ID.into());
743
744 let server_1 = Arc::new(ContextServer::new(
745 server_1_id.clone(),
746 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
747 ));
748 let server_2 = Arc::new(ContextServer::new(
749 server_2_id.clone(),
750 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
751 ));
752
753 store.update(cx, |store, cx| store.test_start_server(server_1, cx));
754
755 cx.run_until_parked();
756
757 cx.update(|cx| {
758 assert_eq!(
759 store.read(cx).status_for_server(&server_1_id),
760 Some(ContextServerStatus::Running)
761 );
762 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
763 });
764
765 store.update(cx, |store, cx| {
766 store.test_start_server(server_2.clone(), cx)
767 });
768
769 cx.run_until_parked();
770
771 cx.update(|cx| {
772 assert_eq!(
773 store.read(cx).status_for_server(&server_1_id),
774 Some(ContextServerStatus::Running)
775 );
776 assert_eq!(
777 store.read(cx).status_for_server(&server_2_id),
778 Some(ContextServerStatus::Running)
779 );
780 });
781
782 store
783 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
784 .unwrap();
785
786 cx.update(|cx| {
787 assert_eq!(
788 store.read(cx).status_for_server(&server_1_id),
789 Some(ContextServerStatus::Running)
790 );
791 assert_eq!(
792 store.read(cx).status_for_server(&server_2_id),
793 Some(ContextServerStatus::Stopped)
794 );
795 });
796 }
797
798 #[gpui::test]
799 async fn test_context_server_status_events(cx: &mut TestAppContext) {
800 const SERVER_1_ID: &str = "mcp-1";
801 const SERVER_2_ID: &str = "mcp-2";
802
803 let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await;
804
805 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
806 let store = cx.new(|cx| {
807 ContextServerStore::test(
808 registry.clone(),
809 project.read(cx).worktree_store(),
810 project.downgrade(),
811 cx,
812 )
813 });
814
815 let server_1_id = ContextServerId(SERVER_1_ID.into());
816 let server_2_id = ContextServerId(SERVER_2_ID.into());
817
818 let server_1 = Arc::new(ContextServer::new(
819 server_1_id.clone(),
820 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
821 ));
822 let server_2 = Arc::new(ContextServer::new(
823 server_2_id.clone(),
824 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
825 ));
826
827 let _server_events = assert_server_events(
828 &store,
829 vec![
830 (server_1_id.clone(), ContextServerStatus::Starting),
831 (server_1_id, ContextServerStatus::Running),
832 (server_2_id.clone(), ContextServerStatus::Starting),
833 (server_2_id.clone(), ContextServerStatus::Running),
834 (server_2_id.clone(), ContextServerStatus::Stopped),
835 ],
836 cx,
837 );
838
839 store.update(cx, |store, cx| store.test_start_server(server_1, cx));
840
841 cx.run_until_parked();
842
843 store.update(cx, |store, cx| {
844 store.test_start_server(server_2.clone(), cx)
845 });
846
847 cx.run_until_parked();
848
849 store
850 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
851 .unwrap();
852 }
853
854 #[gpui::test(iterations = 25)]
855 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
856 const SERVER_1_ID: &str = "mcp-1";
857
858 let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await;
859
860 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
861 let store = cx.new(|cx| {
862 ContextServerStore::test(
863 registry.clone(),
864 project.read(cx).worktree_store(),
865 project.downgrade(),
866 cx,
867 )
868 });
869
870 let server_id = ContextServerId(SERVER_1_ID.into());
871
872 let server_with_same_id_1 = Arc::new(ContextServer::new(
873 server_id.clone(),
874 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
875 ));
876 let server_with_same_id_2 = Arc::new(ContextServer::new(
877 server_id.clone(),
878 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
879 ));
880
881 // If we start another server with the same id, we should report that we stopped the previous one
882 let _server_events = assert_server_events(
883 &store,
884 vec![
885 (server_id.clone(), ContextServerStatus::Starting),
886 (server_id.clone(), ContextServerStatus::Stopped),
887 (server_id.clone(), ContextServerStatus::Starting),
888 (server_id.clone(), ContextServerStatus::Running),
889 ],
890 cx,
891 );
892
893 store.update(cx, |store, cx| {
894 store.test_start_server(server_with_same_id_1.clone(), cx)
895 });
896 store.update(cx, |store, cx| {
897 store.test_start_server(server_with_same_id_2.clone(), cx)
898 });
899
900 cx.run_until_parked();
901
902 cx.update(|cx| {
903 assert_eq!(
904 store.read(cx).status_for_server(&server_id),
905 Some(ContextServerStatus::Running)
906 );
907 });
908 }
909
910 #[gpui::test]
911 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
912 const SERVER_1_ID: &str = "mcp-1";
913 const SERVER_2_ID: &str = "mcp-2";
914
915 let server_1_id = ContextServerId(SERVER_1_ID.into());
916 let server_2_id = ContextServerId(SERVER_2_ID.into());
917
918 let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
919
920 let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await;
921
922 let executor = cx.executor();
923 let store = project.read_with(cx, |project, _| project.context_server_store());
924 store.update(cx, |store, cx| {
925 store.set_context_server_factory(Box::new(move |id, _| {
926 Arc::new(ContextServer::new(
927 id.clone(),
928 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
929 ))
930 }));
931 store.registry().update(cx, |registry, cx| {
932 registry.register_context_server_descriptor(
933 SERVER_1_ID.into(),
934 fake_descriptor_1,
935 cx,
936 );
937 });
938 });
939
940 set_context_server_configuration(
941 vec![(
942 server_1_id.0.clone(),
943 settings::ContextServerSettingsContent::Extension {
944 enabled: true,
945 settings: json!({
946 "somevalue": true
947 }),
948 },
949 )],
950 cx,
951 );
952
953 // Ensure that mcp-1 starts up
954 {
955 let _server_events = assert_server_events(
956 &store,
957 vec![
958 (server_1_id.clone(), ContextServerStatus::Starting),
959 (server_1_id.clone(), ContextServerStatus::Running),
960 ],
961 cx,
962 );
963 cx.run_until_parked();
964 }
965
966 // Ensure that mcp-1 is restarted when the configuration was changed
967 {
968 let _server_events = assert_server_events(
969 &store,
970 vec![
971 (server_1_id.clone(), ContextServerStatus::Stopped),
972 (server_1_id.clone(), ContextServerStatus::Starting),
973 (server_1_id.clone(), ContextServerStatus::Running),
974 ],
975 cx,
976 );
977 set_context_server_configuration(
978 vec![(
979 server_1_id.0.clone(),
980 settings::ContextServerSettingsContent::Extension {
981 enabled: true,
982 settings: json!({
983 "somevalue": false
984 }),
985 },
986 )],
987 cx,
988 );
989
990 cx.run_until_parked();
991 }
992
993 // Ensure that mcp-1 is not restarted when the configuration was not changed
994 {
995 let _server_events = assert_server_events(&store, vec![], cx);
996 set_context_server_configuration(
997 vec![(
998 server_1_id.0.clone(),
999 settings::ContextServerSettingsContent::Extension {
1000 enabled: true,
1001 settings: json!({
1002 "somevalue": false
1003 }),
1004 },
1005 )],
1006 cx,
1007 );
1008
1009 cx.run_until_parked();
1010 }
1011
1012 // Ensure that mcp-2 is started once it is added to the settings
1013 {
1014 let _server_events = assert_server_events(
1015 &store,
1016 vec![
1017 (server_2_id.clone(), ContextServerStatus::Starting),
1018 (server_2_id.clone(), ContextServerStatus::Running),
1019 ],
1020 cx,
1021 );
1022 set_context_server_configuration(
1023 vec![
1024 (
1025 server_1_id.0.clone(),
1026 settings::ContextServerSettingsContent::Extension {
1027 enabled: true,
1028 settings: json!({
1029 "somevalue": false
1030 }),
1031 },
1032 ),
1033 (
1034 server_2_id.0.clone(),
1035 settings::ContextServerSettingsContent::Stdio {
1036 enabled: true,
1037 command: ContextServerCommand {
1038 path: "somebinary".into(),
1039 args: vec!["arg".to_string()],
1040 env: None,
1041 timeout: None,
1042 },
1043 },
1044 ),
1045 ],
1046 cx,
1047 );
1048
1049 cx.run_until_parked();
1050 }
1051
1052 // Ensure that mcp-2 is restarted once the args have changed
1053 {
1054 let _server_events = assert_server_events(
1055 &store,
1056 vec![
1057 (server_2_id.clone(), ContextServerStatus::Stopped),
1058 (server_2_id.clone(), ContextServerStatus::Starting),
1059 (server_2_id.clone(), ContextServerStatus::Running),
1060 ],
1061 cx,
1062 );
1063 set_context_server_configuration(
1064 vec![
1065 (
1066 server_1_id.0.clone(),
1067 settings::ContextServerSettingsContent::Extension {
1068 enabled: true,
1069 settings: json!({
1070 "somevalue": false
1071 }),
1072 },
1073 ),
1074 (
1075 server_2_id.0.clone(),
1076 settings::ContextServerSettingsContent::Stdio {
1077 enabled: true,
1078 command: ContextServerCommand {
1079 path: "somebinary".into(),
1080 args: vec!["anotherArg".to_string()],
1081 env: None,
1082 timeout: None,
1083 },
1084 },
1085 ),
1086 ],
1087 cx,
1088 );
1089
1090 cx.run_until_parked();
1091 }
1092
1093 // Ensure that mcp-2 is removed once it is removed from the settings
1094 {
1095 let _server_events = assert_server_events(
1096 &store,
1097 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1098 cx,
1099 );
1100 set_context_server_configuration(
1101 vec![(
1102 server_1_id.0.clone(),
1103 settings::ContextServerSettingsContent::Extension {
1104 enabled: true,
1105 settings: json!({
1106 "somevalue": false
1107 }),
1108 },
1109 )],
1110 cx,
1111 );
1112
1113 cx.run_until_parked();
1114
1115 cx.update(|cx| {
1116 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1117 });
1118 }
1119
1120 // Ensure that nothing happens if the settings do not change
1121 {
1122 let _server_events = assert_server_events(&store, vec![], cx);
1123 set_context_server_configuration(
1124 vec![(
1125 server_1_id.0.clone(),
1126 settings::ContextServerSettingsContent::Extension {
1127 enabled: true,
1128 settings: json!({
1129 "somevalue": false
1130 }),
1131 },
1132 )],
1133 cx,
1134 );
1135
1136 cx.run_until_parked();
1137
1138 cx.update(|cx| {
1139 assert_eq!(
1140 store.read(cx).status_for_server(&server_1_id),
1141 Some(ContextServerStatus::Running)
1142 );
1143 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1144 });
1145 }
1146 }
1147
1148 #[gpui::test]
1149 async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1150 const SERVER_1_ID: &str = "mcp-1";
1151
1152 let server_1_id = ContextServerId(SERVER_1_ID.into());
1153
1154 let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await;
1155
1156 let executor = cx.executor();
1157 let store = project.read_with(cx, |project, _| project.context_server_store());
1158 store.update(cx, |store, _| {
1159 store.set_context_server_factory(Box::new(move |id, _| {
1160 Arc::new(ContextServer::new(
1161 id.clone(),
1162 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1163 ))
1164 }));
1165 });
1166
1167 set_context_server_configuration(
1168 vec![(
1169 server_1_id.0.clone(),
1170 settings::ContextServerSettingsContent::Stdio {
1171 enabled: true,
1172 command: ContextServerCommand {
1173 path: "somebinary".into(),
1174 args: vec!["arg".to_string()],
1175 env: None,
1176 timeout: None,
1177 },
1178 },
1179 )],
1180 cx,
1181 );
1182
1183 // Ensure that mcp-1 starts up
1184 {
1185 let _server_events = assert_server_events(
1186 &store,
1187 vec![
1188 (server_1_id.clone(), ContextServerStatus::Starting),
1189 (server_1_id.clone(), ContextServerStatus::Running),
1190 ],
1191 cx,
1192 );
1193 cx.run_until_parked();
1194 }
1195
1196 // Ensure that mcp-1 is stopped once it is disabled.
1197 {
1198 let _server_events = assert_server_events(
1199 &store,
1200 vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1201 cx,
1202 );
1203 set_context_server_configuration(
1204 vec![(
1205 server_1_id.0.clone(),
1206 settings::ContextServerSettingsContent::Stdio {
1207 enabled: false,
1208 command: ContextServerCommand {
1209 path: "somebinary".into(),
1210 args: vec!["arg".to_string()],
1211 env: None,
1212 timeout: None,
1213 },
1214 },
1215 )],
1216 cx,
1217 );
1218
1219 cx.run_until_parked();
1220 }
1221
1222 // Ensure that mcp-1 is started once it is enabled again.
1223 {
1224 let _server_events = assert_server_events(
1225 &store,
1226 vec![
1227 (server_1_id.clone(), ContextServerStatus::Starting),
1228 (server_1_id.clone(), ContextServerStatus::Running),
1229 ],
1230 cx,
1231 );
1232 set_context_server_configuration(
1233 vec![(
1234 server_1_id.0.clone(),
1235 settings::ContextServerSettingsContent::Stdio {
1236 enabled: true,
1237 command: ContextServerCommand {
1238 path: "somebinary".into(),
1239 args: vec!["arg".to_string()],
1240 timeout: None,
1241 env: None,
1242 },
1243 },
1244 )],
1245 cx,
1246 );
1247
1248 cx.run_until_parked();
1249 }
1250 }
1251
1252 fn set_context_server_configuration(
1253 context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1254 cx: &mut TestAppContext,
1255 ) {
1256 cx.update(|cx| {
1257 SettingsStore::update_global(cx, |store, cx| {
1258 store.update_user_settings(cx, |content| {
1259 content.project.context_servers.clear();
1260 for (id, config) in context_servers {
1261 content.project.context_servers.insert(id, config);
1262 }
1263 });
1264 })
1265 });
1266 }
1267
1268 #[gpui::test]
1269 async fn test_remote_context_server(cx: &mut TestAppContext) {
1270 const SERVER_ID: &str = "remote-server";
1271 let server_id = ContextServerId(SERVER_ID.into());
1272 let server_url = "http://example.com/api";
1273
1274 let client = FakeHttpClient::create(|_| async move {
1275 use http_client::AsyncBody;
1276
1277 let response = Response::builder()
1278 .status(200)
1279 .header("Content-Type", "application/json")
1280 .body(AsyncBody::from(
1281 serde_json::to_string(&json!({
1282 "jsonrpc": "2.0",
1283 "id": 0,
1284 "result": {
1285 "protocolVersion": "2024-11-05",
1286 "capabilities": {},
1287 "serverInfo": {
1288 "name": "test-server",
1289 "version": "1.0.0"
1290 }
1291 }
1292 }))
1293 .unwrap(),
1294 ))
1295 .unwrap();
1296 Ok(response)
1297 });
1298 cx.update(|cx| cx.set_http_client(client));
1299
1300 let (_fs, project) = setup_context_server_test(cx, json!({ "code.rs": "" }), vec![]).await;
1301
1302 let store = project.read_with(cx, |project, _| project.context_server_store());
1303
1304 set_context_server_configuration(
1305 vec![(
1306 server_id.0.clone(),
1307 settings::ContextServerSettingsContent::Http {
1308 enabled: true,
1309 url: server_url.to_string(),
1310 headers: Default::default(),
1311 timeout: None,
1312 },
1313 )],
1314 cx,
1315 );
1316
1317 let _server_events = assert_server_events(
1318 &store,
1319 vec![
1320 (server_id.clone(), ContextServerStatus::Starting),
1321 (server_id.clone(), ContextServerStatus::Running),
1322 ],
1323 cx,
1324 );
1325 cx.run_until_parked();
1326 }
1327
1328 struct ServerEvents {
1329 received_event_count: Rc<RefCell<usize>>,
1330 expected_event_count: usize,
1331 _subscription: Subscription,
1332 }
1333
1334 impl Drop for ServerEvents {
1335 fn drop(&mut self) {
1336 let actual_event_count = *self.received_event_count.borrow();
1337 assert_eq!(
1338 actual_event_count, self.expected_event_count,
1339 "
1340 Expected to receive {} context server store events, but received {} events",
1341 self.expected_event_count, actual_event_count
1342 );
1343 }
1344 }
1345
1346 #[gpui::test]
1347 async fn test_context_server_global_timeout(cx: &mut TestAppContext) {
1348 cx.update(|cx| {
1349 let settings_store = SettingsStore::test(cx);
1350 cx.set_global(settings_store);
1351 SettingsStore::update_global(cx, |store, cx| {
1352 store
1353 .set_user_settings(r#"{"context_server_timeout": 90}"#, cx)
1354 .expect("Failed to set test user settings");
1355 });
1356 });
1357
1358 let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await;
1359
1360 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1361 let store = cx.new(|cx| {
1362 ContextServerStore::test(
1363 registry.clone(),
1364 project.read(cx).worktree_store(),
1365 project.downgrade(),
1366 cx,
1367 )
1368 });
1369
1370 let result = store.update(cx, |store, cx| {
1371 store.create_context_server(
1372 ContextServerId("test-server".into()),
1373 Arc::new(ContextServerConfiguration::Http {
1374 url: url::Url::parse("http://localhost:8080")
1375 .expect("Failed to parse test URL"),
1376 headers: Default::default(),
1377 timeout: None,
1378 }),
1379 cx,
1380 )
1381 });
1382
1383 assert!(
1384 result.is_ok(),
1385 "Server should be created successfully with global timeout"
1386 );
1387 }
1388
1389 #[gpui::test]
1390 async fn test_context_server_per_server_timeout_override(cx: &mut TestAppContext) {
1391 const SERVER_ID: &str = "test-server";
1392
1393 cx.update(|cx| {
1394 let settings_store = SettingsStore::test(cx);
1395 cx.set_global(settings_store);
1396 SettingsStore::update_global(cx, |store, cx| {
1397 store
1398 .set_user_settings(r#"{"context_server_timeout": 60}"#, cx)
1399 .expect("Failed to set test user settings");
1400 });
1401 });
1402
1403 let (_fs, project) = setup_context_server_test(
1404 cx,
1405 json!({"code.rs": ""}),
1406 vec![(
1407 SERVER_ID.into(),
1408 ContextServerSettings::Http {
1409 enabled: true,
1410 url: "http://localhost:8080".to_string(),
1411 headers: Default::default(),
1412 timeout: Some(120),
1413 },
1414 )],
1415 )
1416 .await;
1417
1418 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1419 let store = cx.new(|cx| {
1420 ContextServerStore::test(
1421 registry.clone(),
1422 project.read(cx).worktree_store(),
1423 project.downgrade(),
1424 cx,
1425 )
1426 });
1427
1428 let result = store.update(cx, |store, cx| {
1429 store.create_context_server(
1430 ContextServerId("test-server".into()),
1431 Arc::new(ContextServerConfiguration::Http {
1432 url: url::Url::parse("http://localhost:8080")
1433 .expect("Failed to parse test URL"),
1434 headers: Default::default(),
1435 timeout: Some(120),
1436 }),
1437 cx,
1438 )
1439 });
1440
1441 assert!(
1442 result.is_ok(),
1443 "Server should be created successfully with per-server timeout override"
1444 );
1445 }
1446
1447 #[gpui::test]
1448 async fn test_context_server_stdio_timeout(cx: &mut TestAppContext) {
1449 let (_fs, project) = setup_context_server_test(cx, json!({"code.rs": ""}), vec![]).await;
1450
1451 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1452 let store = cx.new(|cx| {
1453 ContextServerStore::test(
1454 registry.clone(),
1455 project.read(cx).worktree_store(),
1456 project.downgrade(),
1457 cx,
1458 )
1459 });
1460
1461 let result = store.update(cx, |store, cx| {
1462 store.create_context_server(
1463 ContextServerId("stdio-server".into()),
1464 Arc::new(ContextServerConfiguration::Custom {
1465 command: ContextServerCommand {
1466 path: "/usr/bin/node".into(),
1467 args: vec!["server.js".into()],
1468 env: None,
1469 timeout: Some(180000),
1470 },
1471 }),
1472 cx,
1473 )
1474 });
1475
1476 assert!(
1477 result.is_ok(),
1478 "Stdio server should be created successfully with timeout"
1479 );
1480 }
1481
1482 fn assert_server_events(
1483 store: &Entity<ContextServerStore>,
1484 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1485 cx: &mut TestAppContext,
1486 ) -> ServerEvents {
1487 cx.update(|cx| {
1488 let mut ix = 0;
1489 let received_event_count = Rc::new(RefCell::new(0));
1490 let expected_event_count = expected_events.len();
1491 let subscription = cx.subscribe(store, {
1492 let received_event_count = received_event_count.clone();
1493 move |_, event, _| match event {
1494 Event::ServerStatusChanged {
1495 server_id: actual_server_id,
1496 status: actual_status,
1497 } => {
1498 let (expected_server_id, expected_status) = &expected_events[ix];
1499
1500 assert_eq!(
1501 actual_server_id, expected_server_id,
1502 "Expected different server id at index {}",
1503 ix
1504 );
1505 assert_eq!(
1506 actual_status, expected_status,
1507 "Expected different status at index {}",
1508 ix
1509 );
1510 ix += 1;
1511 *received_event_count.borrow_mut() += 1;
1512 }
1513 }
1514 });
1515 ServerEvents {
1516 expected_event_count,
1517 received_event_count,
1518 _subscription: subscription,
1519 }
1520 })
1521 }
1522
1523 async fn setup_context_server_test(
1524 cx: &mut TestAppContext,
1525 files: serde_json::Value,
1526 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1527 ) -> (Arc<FakeFs>, Entity<Project>) {
1528 cx.update(|cx| {
1529 let settings_store = SettingsStore::test(cx);
1530 cx.set_global(settings_store);
1531 let mut settings = ProjectSettings::get_global(cx).clone();
1532 for (id, config) in context_server_configurations {
1533 settings.context_servers.insert(id, config);
1534 }
1535 ProjectSettings::override_global(settings, cx);
1536 });
1537
1538 let fs = FakeFs::new(cx.executor());
1539 fs.insert_tree(path!("/test"), files).await;
1540 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1541
1542 (fs, project)
1543 }
1544
1545 struct FakeContextServerDescriptor {
1546 path: PathBuf,
1547 }
1548
1549 impl FakeContextServerDescriptor {
1550 fn new(path: impl Into<PathBuf>) -> Self {
1551 Self { path: path.into() }
1552 }
1553 }
1554
1555 impl ContextServerDescriptor for FakeContextServerDescriptor {
1556 fn command(
1557 &self,
1558 _worktree_store: Entity<WorktreeStore>,
1559 _cx: &AsyncApp,
1560 ) -> Task<Result<ContextServerCommand>> {
1561 Task::ready(Ok(ContextServerCommand {
1562 path: self.path.clone(),
1563 args: vec!["arg1".to_string(), "arg2".to_string()],
1564 env: None,
1565 timeout: None,
1566 }))
1567 }
1568
1569 fn configuration(
1570 &self,
1571 _worktree_store: Entity<WorktreeStore>,
1572 _cx: &AsyncApp,
1573 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1574 Task::ready(Ok(None))
1575 }
1576 }
1577}