1pub mod extension;
2pub mod registry;
3
4use std::sync::Arc;
5
6use anyhow::{Context as _, Result};
7use collections::{HashMap, HashSet};
8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
9use futures::{FutureExt as _, future::join_all};
10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
11use registry::ContextServerDescriptorRegistry;
12use settings::{Settings as _, SettingsStore};
13use util::{ResultExt as _, rel_path::RelPath};
14
15use crate::{
16 Project,
17 project_settings::{ContextServerSettings, ProjectSettings},
18 trusted_worktrees::wait_for_workspace_trust,
19 worktree_store::WorktreeStore,
20};
21
22pub fn init(cx: &mut App) {
23 extension::init(cx);
24}
25
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
33
/// Public view of a context server's lifecycle stage.
///
/// Derived from the store's internal state and reported through
/// `Event::ServerStatusChanged`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's startup task is in flight.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// Starting the server failed; carries the error message.
    Error(Arc<str>),
}
41
42impl ContextServerStatus {
43 fn from_state(state: &ContextServerState) -> Self {
44 match state {
45 ContextServerState::Starting { .. } => ContextServerStatus::Starting,
46 ContextServerState::Running { .. } => ContextServerStatus::Running,
47 ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
48 ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
49 }
50 }
51}
52
/// Internal lifecycle state of a server tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was created
/// from, so both can be queried regardless of the current stage.
enum ContextServerState {
    /// Startup has been spawned but has not finished yet.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Startup task spawned by `run_server`; owned here so its lifetime is tied to
        // this state entry.
        _task: Task<()>,
    },
    /// Startup succeeded.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped via `stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Startup failed with the given error message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
73
74impl ContextServerState {
75 pub fn server(&self) -> Arc<ContextServer> {
76 match self {
77 ContextServerState::Starting { server, .. } => server.clone(),
78 ContextServerState::Running { server, .. } => server.clone(),
79 ContextServerState::Stopped { server, .. } => server.clone(),
80 ContextServerState::Error { server, .. } => server.clone(),
81 }
82 }
83
84 pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
85 match self {
86 ContextServerState::Starting { configuration, .. } => configuration.clone(),
87 ContextServerState::Running { configuration, .. } => configuration.clone(),
88 ContextServerState::Stopped { configuration, .. } => configuration.clone(),
89 ContextServerState::Error { configuration, .. } => configuration.clone(),
90 }
91 }
92}
93
/// Concrete, resolved configuration a context server is created from.
///
/// Compared for equality by the maintenance loop to decide whether an existing
/// server must be restarted after a settings change.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A user-configured command, launched as a stdio server.
    Custom {
        command: ContextServerCommand,
    },
    /// A command resolved from an extension's descriptor, plus the user's settings for it.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
    /// A remote server at `url`; `headers` are handed to `ContextServer::http`.
    Http {
        url: url::Url,
        headers: HashMap<String, String>,
    },
}
108
109impl ContextServerConfiguration {
110 pub fn command(&self) -> Option<&ContextServerCommand> {
111 match self {
112 ContextServerConfiguration::Custom { command } => Some(command),
113 ContextServerConfiguration::Extension { command, .. } => Some(command),
114 ContextServerConfiguration::Http { .. } => None,
115 }
116 }
117
118 pub async fn from_settings(
119 settings: ContextServerSettings,
120 id: ContextServerId,
121 registry: Entity<ContextServerDescriptorRegistry>,
122 worktree_store: Entity<WorktreeStore>,
123 cx: &AsyncApp,
124 ) -> Option<Self> {
125 match settings {
126 ContextServerSettings::Stdio {
127 enabled: _,
128 command,
129 } => Some(ContextServerConfiguration::Custom { command }),
130 ContextServerSettings::Extension {
131 enabled: _,
132 settings,
133 } => {
134 let descriptor = cx
135 .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
136 .ok()
137 .flatten()?;
138
139 match descriptor.command(worktree_store, cx).await {
140 Ok(command) => {
141 Some(ContextServerConfiguration::Extension { command, settings })
142 }
143 Err(e) => {
144 log::error!(
145 "Failed to create context server configuration from settings: {e:#}"
146 );
147 None
148 }
149 }
150 }
151 ContextServerSettings::Http {
152 enabled: _,
153 url,
154 headers: auth,
155 } => {
156 let url = url::Url::parse(&url).log_err()?;
157 Some(ContextServerConfiguration::Http { url, headers: auth })
158 }
159 }
160 }
161}
162
/// Optional override for constructing context servers.
///
/// When set on the store, it is used instead of the built-in stdio/HTTP
/// construction (the tests use it to inject fake transports).
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
165
/// Owns a project's context servers and keeps them in sync with the settings
/// and the context server descriptor registry.
pub struct ContextServerStore {
    // Most recently resolved context server settings, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current lifecycle state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    project: WeakEntity<Project>,
    // Source of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // In-flight reconciliation pass, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // When set, used instead of the built-in stdio/HTTP server construction.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a reconciliation is requested while another one is still running.
    needs_server_update: bool,
    // Registry/settings observers, held for as long as the store exists.
    _subscriptions: Vec<Subscription>,
}
177
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// The server identified by `server_id` transitioned to `status`.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
186
187impl ContextServerStore {
188 pub fn new(
189 worktree_store: Entity<WorktreeStore>,
190 weak_project: WeakEntity<Project>,
191 cx: &mut Context<Self>,
192 ) -> Self {
193 Self::new_internal(
194 true,
195 None,
196 ContextServerDescriptorRegistry::default_global(cx),
197 worktree_store,
198 weak_project,
199 cx,
200 )
201 }
202
203 /// Returns all configured context server ids, excluding the ones that are disabled
204 pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
205 self.context_server_settings
206 .iter()
207 .filter(|(_, settings)| settings.enabled())
208 .map(|(id, _)| ContextServerId(id.clone()))
209 .collect()
210 }
211
212 #[cfg(any(test, feature = "test-support"))]
213 pub fn test(
214 registry: Entity<ContextServerDescriptorRegistry>,
215 worktree_store: Entity<WorktreeStore>,
216 weak_project: WeakEntity<Project>,
217 cx: &mut Context<Self>,
218 ) -> Self {
219 Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
220 }
221
222 #[cfg(any(test, feature = "test-support"))]
223 pub fn test_maintain_server_loop(
224 context_server_factory: Option<ContextServerFactory>,
225 registry: Entity<ContextServerDescriptorRegistry>,
226 worktree_store: Entity<WorktreeStore>,
227 weak_project: WeakEntity<Project>,
228 cx: &mut Context<Self>,
229 ) -> Self {
230 Self::new_internal(
231 true,
232 context_server_factory,
233 registry,
234 worktree_store,
235 weak_project,
236 cx,
237 )
238 }
239
240 fn new_internal(
241 maintain_server_loop: bool,
242 context_server_factory: Option<ContextServerFactory>,
243 registry: Entity<ContextServerDescriptorRegistry>,
244 worktree_store: Entity<WorktreeStore>,
245 weak_project: WeakEntity<Project>,
246 cx: &mut Context<Self>,
247 ) -> Self {
248 let subscriptions = if maintain_server_loop {
249 vec![
250 cx.observe(®istry, |this, _registry, cx| {
251 this.available_context_servers_changed(cx);
252 }),
253 cx.observe_global::<SettingsStore>(|this, cx| {
254 let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
255 if &this.context_server_settings == settings {
256 return;
257 }
258 this.context_server_settings = settings.clone();
259 this.available_context_servers_changed(cx);
260 }),
261 ]
262 } else {
263 Vec::new()
264 };
265
266 let mut this = Self {
267 _subscriptions: subscriptions,
268 context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
269 .clone(),
270 worktree_store,
271 project: weak_project,
272 registry,
273 needs_server_update: false,
274 servers: HashMap::default(),
275 update_servers_task: None,
276 context_server_factory,
277 };
278 if maintain_server_loop {
279 this.available_context_servers_changed(cx);
280 }
281 this
282 }
283
284 pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
285 self.servers.get(id).map(|state| state.server())
286 }
287
288 pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
289 if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
290 Some(server.clone())
291 } else {
292 None
293 }
294 }
295
296 pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
297 self.servers.get(id).map(ContextServerStatus::from_state)
298 }
299
300 pub fn configuration_for_server(
301 &self,
302 id: &ContextServerId,
303 ) -> Option<Arc<ContextServerConfiguration>> {
304 self.servers.get(id).map(|state| state.configuration())
305 }
306
307 pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
308 self.servers
309 .keys()
310 .cloned()
311 .chain(
312 self.registry
313 .read(cx)
314 .context_server_descriptors()
315 .into_iter()
316 .map(|(id, _)| ContextServerId(id)),
317 )
318 .collect()
319 }
320
321 pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
322 self.servers
323 .values()
324 .filter_map(|state| {
325 if let ContextServerState::Running { server, .. } = state {
326 Some(server.clone())
327 } else {
328 None
329 }
330 })
331 .collect()
332 }
333
334 pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
335 cx.spawn(async move |this, cx| {
336 let wait_task = this.update(cx, |context_server_store, cx| {
337 context_server_store.project.update(cx, |project, cx| {
338 let remote_host = project.remote_connection_options(cx);
339 wait_for_workspace_trust(remote_host, "context servers", cx)
340 })
341 })??;
342 if let Some(wait_task) = wait_task {
343 wait_task.await;
344 }
345 let this = this.upgrade().context("Context server store dropped")?;
346 let settings = this
347 .update(cx, |this, _| {
348 this.context_server_settings.get(&server.id().0).cloned()
349 })
350 .ok()
351 .flatten()
352 .context("Failed to get context server settings")?;
353
354 if !settings.enabled() {
355 return Ok(());
356 }
357
358 let (registry, worktree_store) = this.update(cx, |this, _| {
359 (this.registry.clone(), this.worktree_store.clone())
360 })?;
361 let configuration = ContextServerConfiguration::from_settings(
362 settings,
363 server.id(),
364 registry,
365 worktree_store,
366 cx,
367 )
368 .await
369 .context("Failed to create context server configuration")?;
370
371 this.update(cx, |this, cx| {
372 this.run_server(server, Arc::new(configuration), cx)
373 })
374 })
375 .detach_and_log_err(cx);
376 }
377
378 pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
379 if matches!(
380 self.servers.get(id),
381 Some(ContextServerState::Stopped { .. })
382 ) {
383 return Ok(());
384 }
385
386 let state = self
387 .servers
388 .remove(id)
389 .context("Context server not found")?;
390
391 let server = state.server();
392 let configuration = state.configuration();
393 let mut result = Ok(());
394 if let ContextServerState::Running { server, .. } = &state {
395 result = server.stop();
396 }
397 drop(state);
398
399 self.update_server_state(
400 id.clone(),
401 ContextServerState::Stopped {
402 configuration,
403 server,
404 },
405 cx,
406 );
407
408 result
409 }
410
411 fn run_server(
412 &mut self,
413 server: Arc<ContextServer>,
414 configuration: Arc<ContextServerConfiguration>,
415 cx: &mut Context<Self>,
416 ) {
417 let id = server.id();
418 if matches!(
419 self.servers.get(&id),
420 Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
421 ) {
422 self.stop_server(&id, cx).log_err();
423 }
424 let task = cx.spawn({
425 let id = server.id();
426 let server = server.clone();
427 let configuration = configuration.clone();
428
429 async move |this, cx| {
430 match server.clone().start(cx).await {
431 Ok(_) => {
432 debug_assert!(server.client().is_some());
433
434 this.update(cx, |this, cx| {
435 this.update_server_state(
436 id.clone(),
437 ContextServerState::Running {
438 server,
439 configuration,
440 },
441 cx,
442 )
443 })
444 .log_err()
445 }
446 Err(err) => {
447 log::error!("{} context server failed to start: {}", id, err);
448 this.update(cx, |this, cx| {
449 this.update_server_state(
450 id.clone(),
451 ContextServerState::Error {
452 configuration,
453 server,
454 error: err.to_string().into(),
455 },
456 cx,
457 )
458 })
459 .log_err()
460 }
461 };
462 }
463 });
464
465 self.update_server_state(
466 id.clone(),
467 ContextServerState::Starting {
468 configuration,
469 _task: task,
470 server,
471 },
472 cx,
473 );
474 }
475
476 fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
477 let state = self
478 .servers
479 .remove(id)
480 .context("Context server not found")?;
481 drop(state);
482 cx.emit(Event::ServerStatusChanged {
483 server_id: id.clone(),
484 status: ContextServerStatus::Stopped,
485 });
486 Ok(())
487 }
488
489 fn create_context_server(
490 &self,
491 id: ContextServerId,
492 configuration: Arc<ContextServerConfiguration>,
493 cx: &mut Context<Self>,
494 ) -> Result<Arc<ContextServer>> {
495 if let Some(factory) = self.context_server_factory.as_ref() {
496 return Ok(factory(id, configuration));
497 }
498
499 match configuration.as_ref() {
500 ContextServerConfiguration::Http { url, headers } => Ok(Arc::new(ContextServer::http(
501 id,
502 url,
503 headers.clone(),
504 cx.http_client(),
505 cx.background_executor().clone(),
506 )?)),
507 _ => {
508 let root_path = self
509 .project
510 .read_with(cx, |project, cx| project.active_project_directory(cx))
511 .ok()
512 .flatten()
513 .or_else(|| {
514 self.worktree_store.read_with(cx, |store, cx| {
515 store.visible_worktrees(cx).fold(None, |acc, item| {
516 if acc.is_none() {
517 item.read(cx).root_dir()
518 } else {
519 acc
520 }
521 })
522 })
523 });
524 Ok(Arc::new(ContextServer::stdio(
525 id,
526 configuration.command().unwrap().clone(),
527 root_path,
528 )))
529 }
530 }
531 }
532
533 fn resolve_context_server_settings<'a>(
534 worktree_store: &'a Entity<WorktreeStore>,
535 cx: &'a App,
536 ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
537 let location = worktree_store
538 .read(cx)
539 .visible_worktrees(cx)
540 .next()
541 .map(|worktree| settings::SettingsLocation {
542 worktree_id: worktree.read(cx).id(),
543 path: RelPath::empty(),
544 });
545 &ProjectSettings::get(location, cx).context_servers
546 }
547
548 fn update_server_state(
549 &mut self,
550 id: ContextServerId,
551 state: ContextServerState,
552 cx: &mut Context<Self>,
553 ) {
554 let status = ContextServerStatus::from_state(&state);
555 self.servers.insert(id.clone(), state);
556 cx.emit(Event::ServerStatusChanged {
557 server_id: id,
558 status,
559 });
560 }
561
562 fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
563 if self.update_servers_task.is_some() {
564 self.needs_server_update = true;
565 } else {
566 self.needs_server_update = false;
567 self.update_servers_task = Some(cx.spawn(async move |this, cx| {
568 if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
569 log::error!("Error maintaining context servers: {}", err);
570 }
571
572 this.update(cx, |this, cx| {
573 this.update_servers_task.take();
574 if this.needs_server_update {
575 this.available_context_servers_changed(cx);
576 }
577 })?;
578
579 Ok(())
580 }));
581 }
582 }
583
584 async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
585 let wait_task = this.update(cx, |context_server_store, cx| {
586 context_server_store.project.update(cx, |project, cx| {
587 let remote_host = project.remote_connection_options(cx);
588 wait_for_workspace_trust(remote_host, "context servers", cx)
589 })
590 })??;
591 if let Some(wait_task) = wait_task {
592 wait_task.await;
593 }
594 let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
595 (
596 this.context_server_settings.clone(),
597 this.registry.clone(),
598 this.worktree_store.clone(),
599 )
600 })?;
601
602 for (id, _) in
603 registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
604 {
605 configured_servers
606 .entry(id)
607 .or_insert(ContextServerSettings::default_extension());
608 }
609
610 let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
611 configured_servers
612 .into_iter()
613 .partition(|(_, settings)| settings.enabled());
614
615 let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
616 let id = ContextServerId(id);
617 ContextServerConfiguration::from_settings(
618 settings,
619 id.clone(),
620 registry.clone(),
621 worktree_store.clone(),
622 cx,
623 )
624 .map(|config| (id, config))
625 }))
626 .await
627 .into_iter()
628 .filter_map(|(id, config)| config.map(|config| (id, config)))
629 .collect::<HashMap<_, _>>();
630
631 let mut servers_to_start = Vec::new();
632 let mut servers_to_remove = HashSet::default();
633 let mut servers_to_stop = HashSet::default();
634
635 this.update(cx, |this, cx| {
636 for server_id in this.servers.keys() {
637 // All servers that are not in desired_servers should be removed from the store.
638 // This can happen if the user removed a server from the context server settings.
639 if !configured_servers.contains_key(server_id) {
640 if disabled_servers.contains_key(&server_id.0) {
641 servers_to_stop.insert(server_id.clone());
642 } else {
643 servers_to_remove.insert(server_id.clone());
644 }
645 }
646 }
647
648 for (id, config) in configured_servers {
649 let state = this.servers.get(&id);
650 let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
651 let existing_config = state.as_ref().map(|state| state.configuration());
652 if existing_config.as_deref() != Some(&config) || is_stopped {
653 let config = Arc::new(config);
654 let server = this.create_context_server(id.clone(), config.clone(), cx)?;
655 servers_to_start.push((server, config));
656 if this.servers.contains_key(&id) {
657 servers_to_stop.insert(id);
658 }
659 }
660 }
661
662 anyhow::Ok(())
663 })??;
664
665 this.update(cx, |this, cx| {
666 for id in servers_to_stop {
667 this.stop_server(&id, cx)?;
668 }
669 for id in servers_to_remove {
670 this.remove_server(&id, cx)?;
671 }
672 for (server, config) in servers_to_start {
673 this.run_server(server, config, cx);
674 }
675 anyhow::Ok(())
676 })?
677 }
678}
679
680#[cfg(test)]
681mod tests {
682 use super::*;
683 use crate::{
684 FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
685 project_settings::ProjectSettings,
686 };
687 use context_server::test::create_fake_transport;
688 use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
689 use http_client::{FakeHttpClient, Response};
690 use serde_json::json;
691 use std::{cell::RefCell, path::PathBuf, rc::Rc};
692 use util::path;
693
694 #[gpui::test]
695 async fn test_context_server_status(cx: &mut TestAppContext) {
696 const SERVER_1_ID: &str = "mcp-1";
697 const SERVER_2_ID: &str = "mcp-2";
698
699 let (_fs, project) = setup_context_server_test(
700 cx,
701 json!({"code.rs": ""}),
702 vec![
703 (SERVER_1_ID.into(), dummy_server_settings()),
704 (SERVER_2_ID.into(), dummy_server_settings()),
705 ],
706 )
707 .await;
708
709 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
710 let store = cx.new(|cx| {
711 ContextServerStore::test(
712 registry.clone(),
713 project.read(cx).worktree_store(),
714 project.downgrade(),
715 cx,
716 )
717 });
718
719 let server_1_id = ContextServerId(SERVER_1_ID.into());
720 let server_2_id = ContextServerId(SERVER_2_ID.into());
721
722 let server_1 = Arc::new(ContextServer::new(
723 server_1_id.clone(),
724 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
725 ));
726 let server_2 = Arc::new(ContextServer::new(
727 server_2_id.clone(),
728 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
729 ));
730
731 store.update(cx, |store, cx| store.start_server(server_1, cx));
732
733 cx.run_until_parked();
734
735 cx.update(|cx| {
736 assert_eq!(
737 store.read(cx).status_for_server(&server_1_id),
738 Some(ContextServerStatus::Running)
739 );
740 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
741 });
742
743 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
744
745 cx.run_until_parked();
746
747 cx.update(|cx| {
748 assert_eq!(
749 store.read(cx).status_for_server(&server_1_id),
750 Some(ContextServerStatus::Running)
751 );
752 assert_eq!(
753 store.read(cx).status_for_server(&server_2_id),
754 Some(ContextServerStatus::Running)
755 );
756 });
757
758 store
759 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
760 .unwrap();
761
762 cx.update(|cx| {
763 assert_eq!(
764 store.read(cx).status_for_server(&server_1_id),
765 Some(ContextServerStatus::Running)
766 );
767 assert_eq!(
768 store.read(cx).status_for_server(&server_2_id),
769 Some(ContextServerStatus::Stopped)
770 );
771 });
772 }
773
774 #[gpui::test]
775 async fn test_context_server_status_events(cx: &mut TestAppContext) {
776 const SERVER_1_ID: &str = "mcp-1";
777 const SERVER_2_ID: &str = "mcp-2";
778
779 let (_fs, project) = setup_context_server_test(
780 cx,
781 json!({"code.rs": ""}),
782 vec![
783 (SERVER_1_ID.into(), dummy_server_settings()),
784 (SERVER_2_ID.into(), dummy_server_settings()),
785 ],
786 )
787 .await;
788
789 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
790 let store = cx.new(|cx| {
791 ContextServerStore::test(
792 registry.clone(),
793 project.read(cx).worktree_store(),
794 project.downgrade(),
795 cx,
796 )
797 });
798
799 let server_1_id = ContextServerId(SERVER_1_ID.into());
800 let server_2_id = ContextServerId(SERVER_2_ID.into());
801
802 let server_1 = Arc::new(ContextServer::new(
803 server_1_id.clone(),
804 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
805 ));
806 let server_2 = Arc::new(ContextServer::new(
807 server_2_id.clone(),
808 Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
809 ));
810
811 let _server_events = assert_server_events(
812 &store,
813 vec![
814 (server_1_id.clone(), ContextServerStatus::Starting),
815 (server_1_id, ContextServerStatus::Running),
816 (server_2_id.clone(), ContextServerStatus::Starting),
817 (server_2_id.clone(), ContextServerStatus::Running),
818 (server_2_id.clone(), ContextServerStatus::Stopped),
819 ],
820 cx,
821 );
822
823 store.update(cx, |store, cx| store.start_server(server_1, cx));
824
825 cx.run_until_parked();
826
827 store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
828
829 cx.run_until_parked();
830
831 store
832 .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
833 .unwrap();
834 }
835
836 #[gpui::test(iterations = 25)]
837 async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
838 const SERVER_1_ID: &str = "mcp-1";
839
840 let (_fs, project) = setup_context_server_test(
841 cx,
842 json!({"code.rs": ""}),
843 vec![(SERVER_1_ID.into(), dummy_server_settings())],
844 )
845 .await;
846
847 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
848 let store = cx.new(|cx| {
849 ContextServerStore::test(
850 registry.clone(),
851 project.read(cx).worktree_store(),
852 project.downgrade(),
853 cx,
854 )
855 });
856
857 let server_id = ContextServerId(SERVER_1_ID.into());
858
859 let server_with_same_id_1 = Arc::new(ContextServer::new(
860 server_id.clone(),
861 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
862 ));
863 let server_with_same_id_2 = Arc::new(ContextServer::new(
864 server_id.clone(),
865 Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
866 ));
867
868 // If we start another server with the same id, we should report that we stopped the previous one
869 let _server_events = assert_server_events(
870 &store,
871 vec![
872 (server_id.clone(), ContextServerStatus::Starting),
873 (server_id.clone(), ContextServerStatus::Stopped),
874 (server_id.clone(), ContextServerStatus::Starting),
875 (server_id.clone(), ContextServerStatus::Running),
876 ],
877 cx,
878 );
879
880 store.update(cx, |store, cx| {
881 store.start_server(server_with_same_id_1.clone(), cx)
882 });
883 store.update(cx, |store, cx| {
884 store.start_server(server_with_same_id_2.clone(), cx)
885 });
886
887 cx.run_until_parked();
888
889 cx.update(|cx| {
890 assert_eq!(
891 store.read(cx).status_for_server(&server_id),
892 Some(ContextServerStatus::Running)
893 );
894 });
895 }
896
897 #[gpui::test]
898 async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
899 const SERVER_1_ID: &str = "mcp-1";
900 const SERVER_2_ID: &str = "mcp-2";
901
902 let server_1_id = ContextServerId(SERVER_1_ID.into());
903 let server_2_id = ContextServerId(SERVER_2_ID.into());
904
905 let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
906
907 let (_fs, project) = setup_context_server_test(
908 cx,
909 json!({"code.rs": ""}),
910 vec![(
911 SERVER_1_ID.into(),
912 ContextServerSettings::Extension {
913 enabled: true,
914 settings: json!({
915 "somevalue": true
916 }),
917 },
918 )],
919 )
920 .await;
921
922 let executor = cx.executor();
923 let registry = cx.new(|cx| {
924 let mut registry = ContextServerDescriptorRegistry::new();
925 registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
926 registry
927 });
928 let store = cx.new(|cx| {
929 ContextServerStore::test_maintain_server_loop(
930 Some(Box::new(move |id, _| {
931 Arc::new(ContextServer::new(
932 id.clone(),
933 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
934 ))
935 })),
936 registry.clone(),
937 project.read(cx).worktree_store(),
938 project.downgrade(),
939 cx,
940 )
941 });
942
943 // Ensure that mcp-1 starts up
944 {
945 let _server_events = assert_server_events(
946 &store,
947 vec![
948 (server_1_id.clone(), ContextServerStatus::Starting),
949 (server_1_id.clone(), ContextServerStatus::Running),
950 ],
951 cx,
952 );
953 cx.run_until_parked();
954 }
955
956 // Ensure that mcp-1 is restarted when the configuration was changed
957 {
958 let _server_events = assert_server_events(
959 &store,
960 vec![
961 (server_1_id.clone(), ContextServerStatus::Stopped),
962 (server_1_id.clone(), ContextServerStatus::Starting),
963 (server_1_id.clone(), ContextServerStatus::Running),
964 ],
965 cx,
966 );
967 set_context_server_configuration(
968 vec![(
969 server_1_id.0.clone(),
970 settings::ContextServerSettingsContent::Extension {
971 enabled: true,
972 settings: json!({
973 "somevalue": false
974 }),
975 },
976 )],
977 cx,
978 );
979
980 cx.run_until_parked();
981 }
982
983 // Ensure that mcp-1 is not restarted when the configuration was not changed
984 {
985 let _server_events = assert_server_events(&store, vec![], cx);
986 set_context_server_configuration(
987 vec![(
988 server_1_id.0.clone(),
989 settings::ContextServerSettingsContent::Extension {
990 enabled: true,
991 settings: json!({
992 "somevalue": false
993 }),
994 },
995 )],
996 cx,
997 );
998
999 cx.run_until_parked();
1000 }
1001
1002 // Ensure that mcp-2 is started once it is added to the settings
1003 {
1004 let _server_events = assert_server_events(
1005 &store,
1006 vec![
1007 (server_2_id.clone(), ContextServerStatus::Starting),
1008 (server_2_id.clone(), ContextServerStatus::Running),
1009 ],
1010 cx,
1011 );
1012 set_context_server_configuration(
1013 vec![
1014 (
1015 server_1_id.0.clone(),
1016 settings::ContextServerSettingsContent::Extension {
1017 enabled: true,
1018 settings: json!({
1019 "somevalue": false
1020 }),
1021 },
1022 ),
1023 (
1024 server_2_id.0.clone(),
1025 settings::ContextServerSettingsContent::Stdio {
1026 enabled: true,
1027 command: ContextServerCommand {
1028 path: "somebinary".into(),
1029 args: vec!["arg".to_string()],
1030 env: None,
1031 timeout: None,
1032 },
1033 },
1034 ),
1035 ],
1036 cx,
1037 );
1038
1039 cx.run_until_parked();
1040 }
1041
1042 // Ensure that mcp-2 is restarted once the args have changed
1043 {
1044 let _server_events = assert_server_events(
1045 &store,
1046 vec![
1047 (server_2_id.clone(), ContextServerStatus::Stopped),
1048 (server_2_id.clone(), ContextServerStatus::Starting),
1049 (server_2_id.clone(), ContextServerStatus::Running),
1050 ],
1051 cx,
1052 );
1053 set_context_server_configuration(
1054 vec![
1055 (
1056 server_1_id.0.clone(),
1057 settings::ContextServerSettingsContent::Extension {
1058 enabled: true,
1059 settings: json!({
1060 "somevalue": false
1061 }),
1062 },
1063 ),
1064 (
1065 server_2_id.0.clone(),
1066 settings::ContextServerSettingsContent::Stdio {
1067 enabled: true,
1068 command: ContextServerCommand {
1069 path: "somebinary".into(),
1070 args: vec!["anotherArg".to_string()],
1071 env: None,
1072 timeout: None,
1073 },
1074 },
1075 ),
1076 ],
1077 cx,
1078 );
1079
1080 cx.run_until_parked();
1081 }
1082
1083 // Ensure that mcp-2 is removed once it is removed from the settings
1084 {
1085 let _server_events = assert_server_events(
1086 &store,
1087 vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1088 cx,
1089 );
1090 set_context_server_configuration(
1091 vec![(
1092 server_1_id.0.clone(),
1093 settings::ContextServerSettingsContent::Extension {
1094 enabled: true,
1095 settings: json!({
1096 "somevalue": false
1097 }),
1098 },
1099 )],
1100 cx,
1101 );
1102
1103 cx.run_until_parked();
1104
1105 cx.update(|cx| {
1106 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1107 });
1108 }
1109
1110 // Ensure that nothing happens if the settings do not change
1111 {
1112 let _server_events = assert_server_events(&store, vec![], cx);
1113 set_context_server_configuration(
1114 vec![(
1115 server_1_id.0.clone(),
1116 settings::ContextServerSettingsContent::Extension {
1117 enabled: true,
1118 settings: json!({
1119 "somevalue": false
1120 }),
1121 },
1122 )],
1123 cx,
1124 );
1125
1126 cx.run_until_parked();
1127
1128 cx.update(|cx| {
1129 assert_eq!(
1130 store.read(cx).status_for_server(&server_1_id),
1131 Some(ContextServerStatus::Running)
1132 );
1133 assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1134 });
1135 }
1136 }
1137
1138 #[gpui::test]
1139 async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1140 const SERVER_1_ID: &str = "mcp-1";
1141
1142 let server_1_id = ContextServerId(SERVER_1_ID.into());
1143
1144 let (_fs, project) = setup_context_server_test(
1145 cx,
1146 json!({"code.rs": ""}),
1147 vec![(
1148 SERVER_1_ID.into(),
1149 ContextServerSettings::Stdio {
1150 enabled: true,
1151 command: ContextServerCommand {
1152 path: "somebinary".into(),
1153 args: vec!["arg".to_string()],
1154 env: None,
1155 timeout: None,
1156 },
1157 },
1158 )],
1159 )
1160 .await;
1161
1162 let executor = cx.executor();
1163 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1164 let store = cx.new(|cx| {
1165 ContextServerStore::test_maintain_server_loop(
1166 Some(Box::new(move |id, _| {
1167 Arc::new(ContextServer::new(
1168 id.clone(),
1169 Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1170 ))
1171 })),
1172 registry.clone(),
1173 project.read(cx).worktree_store(),
1174 project.downgrade(),
1175 cx,
1176 )
1177 });
1178
1179 // Ensure that mcp-1 starts up
1180 {
1181 let _server_events = assert_server_events(
1182 &store,
1183 vec![
1184 (server_1_id.clone(), ContextServerStatus::Starting),
1185 (server_1_id.clone(), ContextServerStatus::Running),
1186 ],
1187 cx,
1188 );
1189 cx.run_until_parked();
1190 }
1191
1192 // Ensure that mcp-1 is stopped once it is disabled.
1193 {
1194 let _server_events = assert_server_events(
1195 &store,
1196 vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1197 cx,
1198 );
1199 set_context_server_configuration(
1200 vec![(
1201 server_1_id.0.clone(),
1202 settings::ContextServerSettingsContent::Stdio {
1203 enabled: false,
1204 command: ContextServerCommand {
1205 path: "somebinary".into(),
1206 args: vec!["arg".to_string()],
1207 env: None,
1208 timeout: None,
1209 },
1210 },
1211 )],
1212 cx,
1213 );
1214
1215 cx.run_until_parked();
1216 }
1217
1218 // Ensure that mcp-1 is started once it is enabled again.
1219 {
1220 let _server_events = assert_server_events(
1221 &store,
1222 vec![
1223 (server_1_id.clone(), ContextServerStatus::Starting),
1224 (server_1_id.clone(), ContextServerStatus::Running),
1225 ],
1226 cx,
1227 );
1228 set_context_server_configuration(
1229 vec![(
1230 server_1_id.0.clone(),
1231 settings::ContextServerSettingsContent::Stdio {
1232 enabled: true,
1233 command: ContextServerCommand {
1234 path: "somebinary".into(),
1235 args: vec!["arg".to_string()],
1236 timeout: None,
1237 env: None,
1238 },
1239 },
1240 )],
1241 cx,
1242 );
1243
1244 cx.run_until_parked();
1245 }
1246 }
1247
1248 fn set_context_server_configuration(
1249 context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1250 cx: &mut TestAppContext,
1251 ) {
1252 cx.update(|cx| {
1253 SettingsStore::update_global(cx, |store, cx| {
1254 store.update_user_settings(cx, |content| {
1255 content.project.context_servers.clear();
1256 for (id, config) in context_servers {
1257 content.project.context_servers.insert(id, config);
1258 }
1259 });
1260 })
1261 });
1262 }
1263
1264 #[gpui::test]
1265 async fn test_remote_context_server(cx: &mut TestAppContext) {
1266 const SERVER_ID: &str = "remote-server";
1267 let server_id = ContextServerId(SERVER_ID.into());
1268 let server_url = "http://example.com/api";
1269
1270 let (_fs, project) = setup_context_server_test(
1271 cx,
1272 json!({ "code.rs": "" }),
1273 vec![(
1274 SERVER_ID.into(),
1275 ContextServerSettings::Http {
1276 enabled: true,
1277 url: server_url.to_string(),
1278 headers: Default::default(),
1279 },
1280 )],
1281 )
1282 .await;
1283
1284 let client = FakeHttpClient::create(|_| async move {
1285 use http_client::AsyncBody;
1286
1287 let response = Response::builder()
1288 .status(200)
1289 .header("Content-Type", "application/json")
1290 .body(AsyncBody::from(
1291 serde_json::to_string(&json!({
1292 "jsonrpc": "2.0",
1293 "id": 0,
1294 "result": {
1295 "protocolVersion": "2024-11-05",
1296 "capabilities": {},
1297 "serverInfo": {
1298 "name": "test-server",
1299 "version": "1.0.0"
1300 }
1301 }
1302 }))
1303 .unwrap(),
1304 ))
1305 .unwrap();
1306 Ok(response)
1307 });
1308 cx.update(|cx| cx.set_http_client(client));
1309 let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1310 let store = cx.new(|cx| {
1311 ContextServerStore::test_maintain_server_loop(
1312 None,
1313 registry.clone(),
1314 project.read(cx).worktree_store(),
1315 project.downgrade(),
1316 cx,
1317 )
1318 });
1319
1320 let _server_events = assert_server_events(
1321 &store,
1322 vec![
1323 (server_id.clone(), ContextServerStatus::Starting),
1324 (server_id.clone(), ContextServerStatus::Running),
1325 ],
1326 cx,
1327 );
1328 cx.run_until_parked();
1329 }
1330
1331 struct ServerEvents {
1332 received_event_count: Rc<RefCell<usize>>,
1333 expected_event_count: usize,
1334 _subscription: Subscription,
1335 }
1336
1337 impl Drop for ServerEvents {
1338 fn drop(&mut self) {
1339 let actual_event_count = *self.received_event_count.borrow();
1340 assert_eq!(
1341 actual_event_count, self.expected_event_count,
1342 "
1343 Expected to receive {} context server store events, but received {} events",
1344 self.expected_event_count, actual_event_count
1345 );
1346 }
1347 }
1348
1349 fn dummy_server_settings() -> ContextServerSettings {
1350 ContextServerSettings::Stdio {
1351 enabled: true,
1352 command: ContextServerCommand {
1353 path: "somebinary".into(),
1354 args: vec!["arg".to_string()],
1355 env: None,
1356 timeout: None,
1357 },
1358 }
1359 }
1360
1361 fn assert_server_events(
1362 store: &Entity<ContextServerStore>,
1363 expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1364 cx: &mut TestAppContext,
1365 ) -> ServerEvents {
1366 cx.update(|cx| {
1367 let mut ix = 0;
1368 let received_event_count = Rc::new(RefCell::new(0));
1369 let expected_event_count = expected_events.len();
1370 let subscription = cx.subscribe(store, {
1371 let received_event_count = received_event_count.clone();
1372 move |_, event, _| match event {
1373 Event::ServerStatusChanged {
1374 server_id: actual_server_id,
1375 status: actual_status,
1376 } => {
1377 let (expected_server_id, expected_status) = &expected_events[ix];
1378
1379 assert_eq!(
1380 actual_server_id, expected_server_id,
1381 "Expected different server id at index {}",
1382 ix
1383 );
1384 assert_eq!(
1385 actual_status, expected_status,
1386 "Expected different status at index {}",
1387 ix
1388 );
1389 ix += 1;
1390 *received_event_count.borrow_mut() += 1;
1391 }
1392 }
1393 });
1394 ServerEvents {
1395 expected_event_count,
1396 received_event_count,
1397 _subscription: subscription,
1398 }
1399 })
1400 }
1401
1402 async fn setup_context_server_test(
1403 cx: &mut TestAppContext,
1404 files: serde_json::Value,
1405 context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1406 ) -> (Arc<FakeFs>, Entity<Project>) {
1407 cx.update(|cx| {
1408 let settings_store = SettingsStore::test(cx);
1409 cx.set_global(settings_store);
1410 let mut settings = ProjectSettings::get_global(cx).clone();
1411 for (id, config) in context_server_configurations {
1412 settings.context_servers.insert(id, config);
1413 }
1414 ProjectSettings::override_global(settings, cx);
1415 });
1416
1417 let fs = FakeFs::new(cx.executor());
1418 fs.insert_tree(path!("/test"), files).await;
1419 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1420
1421 (fs, project)
1422 }
1423
1424 struct FakeContextServerDescriptor {
1425 path: PathBuf,
1426 }
1427
1428 impl FakeContextServerDescriptor {
1429 fn new(path: impl Into<PathBuf>) -> Self {
1430 Self { path: path.into() }
1431 }
1432 }
1433
1434 impl ContextServerDescriptor for FakeContextServerDescriptor {
1435 fn command(
1436 &self,
1437 _worktree_store: Entity<WorktreeStore>,
1438 _cx: &AsyncApp,
1439 ) -> Task<Result<ContextServerCommand>> {
1440 Task::ready(Ok(ContextServerCommand {
1441 path: self.path.clone(),
1442 args: vec!["arg1".to_string(), "arg2".to_string()],
1443 env: None,
1444 timeout: None,
1445 }))
1446 }
1447
1448 fn configuration(
1449 &self,
1450 _worktree_store: Entity<WorktreeStore>,
1451 _cx: &AsyncApp,
1452 ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1453 Task::ready(Ok(None))
1454 }
1455 }
1456}