context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    project_settings::{ContextServerSettings, ProjectSettings},
  17    worktree_store::WorktreeStore,
  18};
  19
  20pub fn init(cx: &mut App) {
  21    extension::init(cx);
  22}
  23
  24actions!(context_server, [Restart]);
  25
  26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  27pub enum ContextServerStatus {
  28    Starting,
  29    Running,
  30    Stopped,
  31    Error(Arc<str>),
  32}
  33
  34impl ContextServerStatus {
  35    fn from_state(state: &ContextServerState) -> Self {
  36        match state {
  37            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  38            ContextServerState::Running { .. } => ContextServerStatus::Running,
  39            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  40            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  41        }
  42    }
  43}
  44
  45enum ContextServerState {
  46    Starting {
  47        server: Arc<ContextServer>,
  48        configuration: Arc<ContextServerConfiguration>,
  49        _task: Task<()>,
  50    },
  51    Running {
  52        server: Arc<ContextServer>,
  53        configuration: Arc<ContextServerConfiguration>,
  54    },
  55    Stopped {
  56        server: Arc<ContextServer>,
  57        configuration: Arc<ContextServerConfiguration>,
  58    },
  59    Error {
  60        server: Arc<ContextServer>,
  61        configuration: Arc<ContextServerConfiguration>,
  62        error: Arc<str>,
  63    },
  64}
  65
  66impl ContextServerState {
  67    pub fn server(&self) -> Arc<ContextServer> {
  68        match self {
  69            ContextServerState::Starting { server, .. } => server.clone(),
  70            ContextServerState::Running { server, .. } => server.clone(),
  71            ContextServerState::Stopped { server, .. } => server.clone(),
  72            ContextServerState::Error { server, .. } => server.clone(),
  73        }
  74    }
  75
  76    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  77        match self {
  78            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  79            ContextServerState::Running { configuration, .. } => configuration.clone(),
  80            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  81            ContextServerState::Error { configuration, .. } => configuration.clone(),
  82        }
  83    }
  84}
  85
  86#[derive(Debug, PartialEq, Eq)]
  87pub enum ContextServerConfiguration {
  88    Custom {
  89        command: ContextServerCommand,
  90    },
  91    Extension {
  92        command: ContextServerCommand,
  93        settings: serde_json::Value,
  94    },
  95}
  96
  97impl ContextServerConfiguration {
  98    pub fn command(&self) -> &ContextServerCommand {
  99        match self {
 100            ContextServerConfiguration::Custom { command } => command,
 101            ContextServerConfiguration::Extension { command, .. } => command,
 102        }
 103    }
 104
 105    pub async fn from_settings(
 106        settings: ContextServerSettings,
 107        id: ContextServerId,
 108        registry: Entity<ContextServerDescriptorRegistry>,
 109        worktree_store: Entity<WorktreeStore>,
 110        cx: &AsyncApp,
 111    ) -> Option<Self> {
 112        match settings {
 113            ContextServerSettings::Custom {
 114                enabled: _,
 115                command,
 116            } => Some(ContextServerConfiguration::Custom { command }),
 117            ContextServerSettings::Extension {
 118                enabled: _,
 119                settings,
 120            } => {
 121                let descriptor = cx
 122                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 123                    .ok()
 124                    .flatten()?;
 125
 126                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 127
 128                Some(ContextServerConfiguration::Extension { command, settings })
 129            }
 130        }
 131    }
 132}
 133
 134pub type ContextServerFactory =
 135    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 136
 137pub struct ContextServerStore {
 138    servers: HashMap<ContextServerId, ContextServerState>,
 139    worktree_store: Entity<WorktreeStore>,
 140    registry: Entity<ContextServerDescriptorRegistry>,
 141    update_servers_task: Option<Task<Result<()>>>,
 142    context_server_factory: Option<ContextServerFactory>,
 143    needs_server_update: bool,
 144    _subscriptions: Vec<Subscription>,
 145}
 146
 147pub enum Event {
 148    ServerStatusChanged {
 149        server_id: ContextServerId,
 150        status: ContextServerStatus,
 151    },
 152}
 153
 154impl EventEmitter<Event> for ContextServerStore {}
 155
 156impl ContextServerStore {
 157    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 158        Self::new_internal(
 159            true,
 160            None,
 161            ContextServerDescriptorRegistry::default_global(cx),
 162            worktree_store,
 163            cx,
 164        )
 165    }
 166
 167    #[cfg(any(test, feature = "test-support"))]
 168    pub fn test(
 169        registry: Entity<ContextServerDescriptorRegistry>,
 170        worktree_store: Entity<WorktreeStore>,
 171        cx: &mut Context<Self>,
 172    ) -> Self {
 173        Self::new_internal(false, None, registry, worktree_store, cx)
 174    }
 175
 176    #[cfg(any(test, feature = "test-support"))]
 177    pub fn test_maintain_server_loop(
 178        context_server_factory: ContextServerFactory,
 179        registry: Entity<ContextServerDescriptorRegistry>,
 180        worktree_store: Entity<WorktreeStore>,
 181        cx: &mut Context<Self>,
 182    ) -> Self {
 183        Self::new_internal(
 184            true,
 185            Some(context_server_factory),
 186            registry,
 187            worktree_store,
 188            cx,
 189        )
 190    }
 191
 192    fn new_internal(
 193        maintain_server_loop: bool,
 194        context_server_factory: Option<ContextServerFactory>,
 195        registry: Entity<ContextServerDescriptorRegistry>,
 196        worktree_store: Entity<WorktreeStore>,
 197        cx: &mut Context<Self>,
 198    ) -> Self {
 199        let subscriptions = if maintain_server_loop {
 200            vec![
 201                cx.observe(&registry, |this, _registry, cx| {
 202                    this.available_context_servers_changed(cx);
 203                }),
 204                cx.observe_global::<SettingsStore>(|this, cx| {
 205                    this.available_context_servers_changed(cx);
 206                }),
 207            ]
 208        } else {
 209            Vec::new()
 210        };
 211
 212        let mut this = Self {
 213            _subscriptions: subscriptions,
 214            worktree_store,
 215            registry,
 216            needs_server_update: false,
 217            servers: HashMap::default(),
 218            update_servers_task: None,
 219            context_server_factory,
 220        };
 221        if maintain_server_loop {
 222            this.available_context_servers_changed(cx);
 223        }
 224        this
 225    }
 226
 227    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 228        self.servers.get(id).map(|state| state.server())
 229    }
 230
 231    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 232        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 233            Some(server.clone())
 234        } else {
 235            None
 236        }
 237    }
 238
 239    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 240        self.servers.get(id).map(ContextServerStatus::from_state)
 241    }
 242
 243    pub fn configuration_for_server(
 244        &self,
 245        id: &ContextServerId,
 246    ) -> Option<Arc<ContextServerConfiguration>> {
 247        self.servers.get(id).map(|state| state.configuration())
 248    }
 249
 250    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 251        self.servers.keys().cloned().collect()
 252    }
 253
 254    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 255        self.servers
 256            .values()
 257            .filter_map(|state| {
 258                if let ContextServerState::Running { server, .. } = state {
 259                    Some(server.clone())
 260                } else {
 261                    None
 262                }
 263            })
 264            .collect()
 265    }
 266
 267    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 268        cx.spawn(async move |this, cx| {
 269            let this = this.upgrade().context("Context server store dropped")?;
 270            let settings = this
 271                .update(cx, |this, cx| {
 272                    this.context_server_settings(cx)
 273                        .get(&server.id().0)
 274                        .cloned()
 275                })
 276                .ok()
 277                .flatten()
 278                .context("Failed to get context server settings")?;
 279
 280            if !settings.enabled() {
 281                return Ok(());
 282            }
 283
 284            let (registry, worktree_store) = this.update(cx, |this, _| {
 285                (this.registry.clone(), this.worktree_store.clone())
 286            })?;
 287            let configuration = ContextServerConfiguration::from_settings(
 288                settings,
 289                server.id(),
 290                registry,
 291                worktree_store,
 292                cx,
 293            )
 294            .await
 295            .context("Failed to create context server configuration")?;
 296
 297            this.update(cx, |this, cx| {
 298                this.run_server(server, Arc::new(configuration), cx)
 299            })
 300        })
 301        .detach_and_log_err(cx);
 302    }
 303
 304    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 305        if matches!(
 306            self.servers.get(id),
 307            Some(ContextServerState::Stopped { .. })
 308        ) {
 309            return Ok(());
 310        }
 311
 312        let state = self
 313            .servers
 314            .remove(id)
 315            .context("Context server not found")?;
 316
 317        let server = state.server();
 318        let configuration = state.configuration();
 319        let mut result = Ok(());
 320        if let ContextServerState::Running { server, .. } = &state {
 321            result = server.stop();
 322        }
 323        drop(state);
 324
 325        self.update_server_state(
 326            id.clone(),
 327            ContextServerState::Stopped {
 328                configuration,
 329                server,
 330            },
 331            cx,
 332        );
 333
 334        result
 335    }
 336
 337    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 338        if let Some(state) = self.servers.get(&id) {
 339            let configuration = state.configuration();
 340
 341            self.stop_server(&state.server().id(), cx)?;
 342            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 343            self.run_server(new_server, configuration, cx);
 344        }
 345        Ok(())
 346    }
 347
 348    fn run_server(
 349        &mut self,
 350        server: Arc<ContextServer>,
 351        configuration: Arc<ContextServerConfiguration>,
 352        cx: &mut Context<Self>,
 353    ) {
 354        let id = server.id();
 355        if matches!(
 356            self.servers.get(&id),
 357            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 358        ) {
 359            self.stop_server(&id, cx).log_err();
 360        }
 361
 362        let task = cx.spawn({
 363            let id = server.id();
 364            let server = server.clone();
 365            let configuration = configuration.clone();
 366            async move |this, cx| {
 367                match server.clone().start(&cx).await {
 368                    Ok(_) => {
 369                        log::info!("Started {} context server", id);
 370                        debug_assert!(server.client().is_some());
 371
 372                        this.update(cx, |this, cx| {
 373                            this.update_server_state(
 374                                id.clone(),
 375                                ContextServerState::Running {
 376                                    server,
 377                                    configuration,
 378                                },
 379                                cx,
 380                            )
 381                        })
 382                        .log_err()
 383                    }
 384                    Err(err) => {
 385                        log::error!("{} context server failed to start: {}", id, err);
 386                        this.update(cx, |this, cx| {
 387                            this.update_server_state(
 388                                id.clone(),
 389                                ContextServerState::Error {
 390                                    configuration,
 391                                    server,
 392                                    error: err.to_string().into(),
 393                                },
 394                                cx,
 395                            )
 396                        })
 397                        .log_err()
 398                    }
 399                };
 400            }
 401        });
 402
 403        self.update_server_state(
 404            id.clone(),
 405            ContextServerState::Starting {
 406                configuration,
 407                _task: task,
 408                server,
 409            },
 410            cx,
 411        );
 412    }
 413
 414    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 415        let state = self
 416            .servers
 417            .remove(id)
 418            .context("Context server not found")?;
 419        drop(state);
 420        cx.emit(Event::ServerStatusChanged {
 421            server_id: id.clone(),
 422            status: ContextServerStatus::Stopped,
 423        });
 424        Ok(())
 425    }
 426
 427    fn create_context_server(
 428        &self,
 429        id: ContextServerId,
 430        configuration: Arc<ContextServerConfiguration>,
 431    ) -> Result<Arc<ContextServer>> {
 432        if let Some(factory) = self.context_server_factory.as_ref() {
 433            Ok(factory(id, configuration))
 434        } else {
 435            Ok(Arc::new(ContextServer::stdio(
 436                id,
 437                configuration.command().clone(),
 438            )))
 439        }
 440    }
 441
 442    fn context_server_settings<'a>(
 443        &'a self,
 444        cx: &'a App,
 445    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 446        let location = self
 447            .worktree_store
 448            .read(cx)
 449            .visible_worktrees(cx)
 450            .next()
 451            .map(|worktree| settings::SettingsLocation {
 452                worktree_id: worktree.read(cx).id(),
 453                path: Path::new(""),
 454            });
 455        &ProjectSettings::get(location, cx).context_servers
 456    }
 457
 458    fn update_server_state(
 459        &mut self,
 460        id: ContextServerId,
 461        state: ContextServerState,
 462        cx: &mut Context<Self>,
 463    ) {
 464        let status = ContextServerStatus::from_state(&state);
 465        self.servers.insert(id.clone(), state);
 466        cx.emit(Event::ServerStatusChanged {
 467            server_id: id,
 468            status,
 469        });
 470    }
 471
 472    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 473        if self.update_servers_task.is_some() {
 474            self.needs_server_update = true;
 475        } else {
 476            self.needs_server_update = false;
 477            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 478                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 479                    log::error!("Error maintaining context servers: {}", err);
 480                }
 481
 482                this.update(cx, |this, cx| {
 483                    this.update_servers_task.take();
 484                    if this.needs_server_update {
 485                        this.available_context_servers_changed(cx);
 486                    }
 487                })?;
 488
 489                Ok(())
 490            }));
 491        }
 492    }
 493
 494    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 495        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, cx| {
 496            (
 497                this.context_server_settings(cx).clone(),
 498                this.registry.clone(),
 499                this.worktree_store.clone(),
 500            )
 501        })?;
 502
 503        for (id, _) in
 504            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 505        {
 506            configured_servers
 507                .entry(id)
 508                .or_insert(ContextServerSettings::default_extension());
 509        }
 510
 511        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 512            configured_servers
 513                .into_iter()
 514                .partition(|(_, settings)| settings.enabled());
 515
 516        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 517            let id = ContextServerId(id);
 518            ContextServerConfiguration::from_settings(
 519                settings,
 520                id.clone(),
 521                registry.clone(),
 522                worktree_store.clone(),
 523                cx,
 524            )
 525            .map(|config| (id, config))
 526        }))
 527        .await
 528        .into_iter()
 529        .filter_map(|(id, config)| config.map(|config| (id, config)))
 530        .collect::<HashMap<_, _>>();
 531
 532        let mut servers_to_start = Vec::new();
 533        let mut servers_to_remove = HashSet::default();
 534        let mut servers_to_stop = HashSet::default();
 535
 536        this.update(cx, |this, _cx| {
 537            for server_id in this.servers.keys() {
 538                // All servers that are not in desired_servers should be removed from the store.
 539                // This can happen if the user removed a server from the context server settings.
 540                if !configured_servers.contains_key(&server_id) {
 541                    if disabled_servers.contains_key(&server_id.0) {
 542                        servers_to_stop.insert(server_id.clone());
 543                    } else {
 544                        servers_to_remove.insert(server_id.clone());
 545                    }
 546                }
 547            }
 548
 549            for (id, config) in configured_servers {
 550                let state = this.servers.get(&id);
 551                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 552                let existing_config = state.as_ref().map(|state| state.configuration());
 553                if existing_config.as_deref() != Some(&config) || is_stopped {
 554                    let config = Arc::new(config);
 555                    if let Some(server) = this
 556                        .create_context_server(id.clone(), config.clone())
 557                        .log_err()
 558                    {
 559                        servers_to_start.push((server, config));
 560                        if this.servers.contains_key(&id) {
 561                            servers_to_stop.insert(id);
 562                        }
 563                    }
 564                }
 565            }
 566        })?;
 567
 568        this.update(cx, |this, cx| {
 569            for id in servers_to_stop {
 570                this.stop_server(&id, cx)?;
 571            }
 572            for id in servers_to_remove {
 573                this.remove_server(&id, cx)?;
 574            }
 575            for (server, config) in servers_to_start {
 576                this.run_server(server, config, cx);
 577            }
 578            anyhow::Ok(())
 579        })?
 580    }
 581}
 582
 583#[cfg(test)]
 584mod tests {
 585    use super::*;
 586    use crate::{
 587        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 588        project_settings::ProjectSettings,
 589    };
 590    use context_server::test::create_fake_transport;
 591    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 592    use serde_json::json;
 593    use std::{cell::RefCell, rc::Rc};
 594    use util::path;
 595
 596    #[gpui::test]
 597    async fn test_context_server_status(cx: &mut TestAppContext) {
 598        const SERVER_1_ID: &'static str = "mcp-1";
 599        const SERVER_2_ID: &'static str = "mcp-2";
 600
 601        let (_fs, project) = setup_context_server_test(
 602            cx,
 603            json!({"code.rs": ""}),
 604            vec![
 605                (SERVER_1_ID.into(), dummy_server_settings()),
 606                (SERVER_2_ID.into(), dummy_server_settings()),
 607            ],
 608        )
 609        .await;
 610
 611        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 612        let store = cx.new(|cx| {
 613            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 614        });
 615
 616        let server_1_id = ContextServerId(SERVER_1_ID.into());
 617        let server_2_id = ContextServerId(SERVER_2_ID.into());
 618
 619        let server_1 = Arc::new(ContextServer::new(
 620            server_1_id.clone(),
 621            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 622        ));
 623        let server_2 = Arc::new(ContextServer::new(
 624            server_2_id.clone(),
 625            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 626        ));
 627
 628        store.update(cx, |store, cx| store.start_server(server_1, cx));
 629
 630        cx.run_until_parked();
 631
 632        cx.update(|cx| {
 633            assert_eq!(
 634                store.read(cx).status_for_server(&server_1_id),
 635                Some(ContextServerStatus::Running)
 636            );
 637            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 638        });
 639
 640        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 641
 642        cx.run_until_parked();
 643
 644        cx.update(|cx| {
 645            assert_eq!(
 646                store.read(cx).status_for_server(&server_1_id),
 647                Some(ContextServerStatus::Running)
 648            );
 649            assert_eq!(
 650                store.read(cx).status_for_server(&server_2_id),
 651                Some(ContextServerStatus::Running)
 652            );
 653        });
 654
 655        store
 656            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 657            .unwrap();
 658
 659        cx.update(|cx| {
 660            assert_eq!(
 661                store.read(cx).status_for_server(&server_1_id),
 662                Some(ContextServerStatus::Running)
 663            );
 664            assert_eq!(
 665                store.read(cx).status_for_server(&server_2_id),
 666                Some(ContextServerStatus::Stopped)
 667            );
 668        });
 669    }
 670
 671    #[gpui::test]
 672    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 673        const SERVER_1_ID: &'static str = "mcp-1";
 674        const SERVER_2_ID: &'static str = "mcp-2";
 675
 676        let (_fs, project) = setup_context_server_test(
 677            cx,
 678            json!({"code.rs": ""}),
 679            vec![
 680                (SERVER_1_ID.into(), dummy_server_settings()),
 681                (SERVER_2_ID.into(), dummy_server_settings()),
 682            ],
 683        )
 684        .await;
 685
 686        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 687        let store = cx.new(|cx| {
 688            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 689        });
 690
 691        let server_1_id = ContextServerId(SERVER_1_ID.into());
 692        let server_2_id = ContextServerId(SERVER_2_ID.into());
 693
 694        let server_1 = Arc::new(ContextServer::new(
 695            server_1_id.clone(),
 696            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 697        ));
 698        let server_2 = Arc::new(ContextServer::new(
 699            server_2_id.clone(),
 700            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 701        ));
 702
 703        let _server_events = assert_server_events(
 704            &store,
 705            vec![
 706                (server_1_id.clone(), ContextServerStatus::Starting),
 707                (server_1_id.clone(), ContextServerStatus::Running),
 708                (server_2_id.clone(), ContextServerStatus::Starting),
 709                (server_2_id.clone(), ContextServerStatus::Running),
 710                (server_2_id.clone(), ContextServerStatus::Stopped),
 711            ],
 712            cx,
 713        );
 714
 715        store.update(cx, |store, cx| store.start_server(server_1, cx));
 716
 717        cx.run_until_parked();
 718
 719        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 720
 721        cx.run_until_parked();
 722
 723        store
 724            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 725            .unwrap();
 726    }
 727
 728    #[gpui::test(iterations = 25)]
 729    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 730        const SERVER_1_ID: &'static str = "mcp-1";
 731
 732        let (_fs, project) = setup_context_server_test(
 733            cx,
 734            json!({"code.rs": ""}),
 735            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 736        )
 737        .await;
 738
 739        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 740        let store = cx.new(|cx| {
 741            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 742        });
 743
 744        let server_id = ContextServerId(SERVER_1_ID.into());
 745
 746        let server_with_same_id_1 = Arc::new(ContextServer::new(
 747            server_id.clone(),
 748            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 749        ));
 750        let server_with_same_id_2 = Arc::new(ContextServer::new(
 751            server_id.clone(),
 752            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 753        ));
 754
 755        // If we start another server with the same id, we should report that we stopped the previous one
 756        let _server_events = assert_server_events(
 757            &store,
 758            vec![
 759                (server_id.clone(), ContextServerStatus::Starting),
 760                (server_id.clone(), ContextServerStatus::Stopped),
 761                (server_id.clone(), ContextServerStatus::Starting),
 762                (server_id.clone(), ContextServerStatus::Running),
 763            ],
 764            cx,
 765        );
 766
 767        store.update(cx, |store, cx| {
 768            store.start_server(server_with_same_id_1.clone(), cx)
 769        });
 770        store.update(cx, |store, cx| {
 771            store.start_server(server_with_same_id_2.clone(), cx)
 772        });
 773
 774        cx.run_until_parked();
 775
 776        cx.update(|cx| {
 777            assert_eq!(
 778                store.read(cx).status_for_server(&server_id),
 779                Some(ContextServerStatus::Running)
 780            );
 781        });
 782    }
 783
 784    #[gpui::test]
 785    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 786        const SERVER_1_ID: &'static str = "mcp-1";
 787        const SERVER_2_ID: &'static str = "mcp-2";
 788
 789        let server_1_id = ContextServerId(SERVER_1_ID.into());
 790        let server_2_id = ContextServerId(SERVER_2_ID.into());
 791
 792        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 793
 794        let (_fs, project) = setup_context_server_test(
 795            cx,
 796            json!({"code.rs": ""}),
 797            vec![(
 798                SERVER_1_ID.into(),
 799                ContextServerSettings::Extension {
 800                    enabled: true,
 801                    settings: json!({
 802                        "somevalue": true
 803                    }),
 804                },
 805            )],
 806        )
 807        .await;
 808
 809        let executor = cx.executor();
 810        let registry = cx.new(|_| {
 811            let mut registry = ContextServerDescriptorRegistry::new();
 812            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
 813            registry
 814        });
 815        let store = cx.new(|cx| {
 816            ContextServerStore::test_maintain_server_loop(
 817                Box::new(move |id, _| {
 818                    Arc::new(ContextServer::new(
 819                        id.clone(),
 820                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 821                    ))
 822                }),
 823                registry.clone(),
 824                project.read(cx).worktree_store(),
 825                cx,
 826            )
 827        });
 828
 829        // Ensure that mcp-1 starts up
 830        {
 831            let _server_events = assert_server_events(
 832                &store,
 833                vec![
 834                    (server_1_id.clone(), ContextServerStatus::Starting),
 835                    (server_1_id.clone(), ContextServerStatus::Running),
 836                ],
 837                cx,
 838            );
 839            cx.run_until_parked();
 840        }
 841
 842        // Ensure that mcp-1 is restarted when the configuration was changed
 843        {
 844            let _server_events = assert_server_events(
 845                &store,
 846                vec![
 847                    (server_1_id.clone(), ContextServerStatus::Stopped),
 848                    (server_1_id.clone(), ContextServerStatus::Starting),
 849                    (server_1_id.clone(), ContextServerStatus::Running),
 850                ],
 851                cx,
 852            );
 853            set_context_server_configuration(
 854                vec![(
 855                    server_1_id.0.clone(),
 856                    ContextServerSettings::Extension {
 857                        enabled: true,
 858                        settings: json!({
 859                            "somevalue": false
 860                        }),
 861                    },
 862                )],
 863                cx,
 864            );
 865
 866            cx.run_until_parked();
 867        }
 868
 869        // Ensure that mcp-1 is not restarted when the configuration was not changed
 870        {
 871            let _server_events = assert_server_events(&store, vec![], cx);
 872            set_context_server_configuration(
 873                vec![(
 874                    server_1_id.0.clone(),
 875                    ContextServerSettings::Extension {
 876                        enabled: true,
 877                        settings: json!({
 878                            "somevalue": false
 879                        }),
 880                    },
 881                )],
 882                cx,
 883            );
 884
 885            cx.run_until_parked();
 886        }
 887
 888        // Ensure that mcp-2 is started once it is added to the settings
 889        {
 890            let _server_events = assert_server_events(
 891                &store,
 892                vec![
 893                    (server_2_id.clone(), ContextServerStatus::Starting),
 894                    (server_2_id.clone(), ContextServerStatus::Running),
 895                ],
 896                cx,
 897            );
 898            set_context_server_configuration(
 899                vec![
 900                    (
 901                        server_1_id.0.clone(),
 902                        ContextServerSettings::Extension {
 903                            enabled: true,
 904                            settings: json!({
 905                                "somevalue": false
 906                            }),
 907                        },
 908                    ),
 909                    (
 910                        server_2_id.0.clone(),
 911                        ContextServerSettings::Custom {
 912                            enabled: true,
 913                            command: ContextServerCommand {
 914                                path: "somebinary".to_string(),
 915                                args: vec!["arg".to_string()],
 916                                env: None,
 917                            },
 918                        },
 919                    ),
 920                ],
 921                cx,
 922            );
 923
 924            cx.run_until_parked();
 925        }
 926
 927        // Ensure that mcp-2 is restarted once the args have changed
 928        {
 929            let _server_events = assert_server_events(
 930                &store,
 931                vec![
 932                    (server_2_id.clone(), ContextServerStatus::Stopped),
 933                    (server_2_id.clone(), ContextServerStatus::Starting),
 934                    (server_2_id.clone(), ContextServerStatus::Running),
 935                ],
 936                cx,
 937            );
 938            set_context_server_configuration(
 939                vec![
 940                    (
 941                        server_1_id.0.clone(),
 942                        ContextServerSettings::Extension {
 943                            enabled: true,
 944                            settings: json!({
 945                                "somevalue": false
 946                            }),
 947                        },
 948                    ),
 949                    (
 950                        server_2_id.0.clone(),
 951                        ContextServerSettings::Custom {
 952                            enabled: true,
 953                            command: ContextServerCommand {
 954                                path: "somebinary".to_string(),
 955                                args: vec!["anotherArg".to_string()],
 956                                env: None,
 957                            },
 958                        },
 959                    ),
 960                ],
 961                cx,
 962            );
 963
 964            cx.run_until_parked();
 965        }
 966
 967        // Ensure that mcp-2 is removed once it is removed from the settings
 968        {
 969            let _server_events = assert_server_events(
 970                &store,
 971                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 972                cx,
 973            );
 974            set_context_server_configuration(
 975                vec![(
 976                    server_1_id.0.clone(),
 977                    ContextServerSettings::Extension {
 978                        enabled: true,
 979                        settings: json!({
 980                            "somevalue": false
 981                        }),
 982                    },
 983                )],
 984                cx,
 985            );
 986
 987            cx.run_until_parked();
 988
 989            cx.update(|cx| {
 990                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 991            });
 992        }
 993    }
 994
 995    #[gpui::test]
 996    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
 997        const SERVER_1_ID: &'static str = "mcp-1";
 998
 999        let server_1_id = ContextServerId(SERVER_1_ID.into());
1000
1001        let (_fs, project) = setup_context_server_test(
1002            cx,
1003            json!({"code.rs": ""}),
1004            vec![(
1005                SERVER_1_ID.into(),
1006                ContextServerSettings::Custom {
1007                    enabled: true,
1008                    command: ContextServerCommand {
1009                        path: "somebinary".to_string(),
1010                        args: vec!["arg".to_string()],
1011                        env: None,
1012                    },
1013                },
1014            )],
1015        )
1016        .await;
1017
1018        let executor = cx.executor();
1019        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1020        let store = cx.new(|cx| {
1021            ContextServerStore::test_maintain_server_loop(
1022                Box::new(move |id, _| {
1023                    Arc::new(ContextServer::new(
1024                        id.clone(),
1025                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1026                    ))
1027                }),
1028                registry.clone(),
1029                project.read(cx).worktree_store(),
1030                cx,
1031            )
1032        });
1033
1034        // Ensure that mcp-1 starts up
1035        {
1036            let _server_events = assert_server_events(
1037                &store,
1038                vec![
1039                    (server_1_id.clone(), ContextServerStatus::Starting),
1040                    (server_1_id.clone(), ContextServerStatus::Running),
1041                ],
1042                cx,
1043            );
1044            cx.run_until_parked();
1045        }
1046
1047        // Ensure that mcp-1 is stopped once it is disabled.
1048        {
1049            let _server_events = assert_server_events(
1050                &store,
1051                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1052                cx,
1053            );
1054            set_context_server_configuration(
1055                vec![(
1056                    server_1_id.0.clone(),
1057                    ContextServerSettings::Custom {
1058                        enabled: false,
1059                        command: ContextServerCommand {
1060                            path: "somebinary".to_string(),
1061                            args: vec!["arg".to_string()],
1062                            env: None,
1063                        },
1064                    },
1065                )],
1066                cx,
1067            );
1068
1069            cx.run_until_parked();
1070        }
1071
1072        // Ensure that mcp-1 is started once it is enabled again.
1073        {
1074            let _server_events = assert_server_events(
1075                &store,
1076                vec![
1077                    (server_1_id.clone(), ContextServerStatus::Starting),
1078                    (server_1_id.clone(), ContextServerStatus::Running),
1079                ],
1080                cx,
1081            );
1082            set_context_server_configuration(
1083                vec![(
1084                    server_1_id.0.clone(),
1085                    ContextServerSettings::Custom {
1086                        enabled: true,
1087                        command: ContextServerCommand {
1088                            path: "somebinary".to_string(),
1089                            args: vec!["arg".to_string()],
1090                            env: None,
1091                        },
1092                    },
1093                )],
1094                cx,
1095            );
1096
1097            cx.run_until_parked();
1098        }
1099    }
1100
1101    fn set_context_server_configuration(
1102        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1103        cx: &mut TestAppContext,
1104    ) {
1105        cx.update(|cx| {
1106            SettingsStore::update_global(cx, |store, cx| {
1107                let mut settings = ProjectSettings::default();
1108                for (id, config) in context_servers {
1109                    settings.context_servers.insert(id, config);
1110                }
1111                store
1112                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1113                    .unwrap();
1114            })
1115        });
1116    }
1117
1118    struct ServerEvents {
1119        received_event_count: Rc<RefCell<usize>>,
1120        expected_event_count: usize,
1121        _subscription: Subscription,
1122    }
1123
1124    impl Drop for ServerEvents {
1125        fn drop(&mut self) {
1126            let actual_event_count = *self.received_event_count.borrow();
1127            assert_eq!(
1128                actual_event_count, self.expected_event_count,
1129                "
1130                Expected to receive {} context server store events, but received {} events",
1131                self.expected_event_count, actual_event_count
1132            );
1133        }
1134    }
1135
1136    fn dummy_server_settings() -> ContextServerSettings {
1137        ContextServerSettings::Custom {
1138            enabled: true,
1139            command: ContextServerCommand {
1140                path: "somebinary".to_string(),
1141                args: vec!["arg".to_string()],
1142                env: None,
1143            },
1144        }
1145    }
1146
1147    fn assert_server_events(
1148        store: &Entity<ContextServerStore>,
1149        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1150        cx: &mut TestAppContext,
1151    ) -> ServerEvents {
1152        cx.update(|cx| {
1153            let mut ix = 0;
1154            let received_event_count = Rc::new(RefCell::new(0));
1155            let expected_event_count = expected_events.len();
1156            let subscription = cx.subscribe(store, {
1157                let received_event_count = received_event_count.clone();
1158                move |_, event, _| match event {
1159                    Event::ServerStatusChanged {
1160                        server_id: actual_server_id,
1161                        status: actual_status,
1162                    } => {
1163                        let (expected_server_id, expected_status) = &expected_events[ix];
1164
1165                        assert_eq!(
1166                            actual_server_id, expected_server_id,
1167                            "Expected different server id at index {}",
1168                            ix
1169                        );
1170                        assert_eq!(
1171                            actual_status, expected_status,
1172                            "Expected different status at index {}",
1173                            ix
1174                        );
1175                        ix += 1;
1176                        *received_event_count.borrow_mut() += 1;
1177                    }
1178                }
1179            });
1180            ServerEvents {
1181                expected_event_count,
1182                received_event_count,
1183                _subscription: subscription,
1184            }
1185        })
1186    }
1187
1188    async fn setup_context_server_test(
1189        cx: &mut TestAppContext,
1190        files: serde_json::Value,
1191        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1192    ) -> (Arc<FakeFs>, Entity<Project>) {
1193        cx.update(|cx| {
1194            let settings_store = SettingsStore::test(cx);
1195            cx.set_global(settings_store);
1196            Project::init_settings(cx);
1197            let mut settings = ProjectSettings::get_global(cx).clone();
1198            for (id, config) in context_server_configurations {
1199                settings.context_servers.insert(id, config);
1200            }
1201            ProjectSettings::override_global(settings, cx);
1202        });
1203
1204        let fs = FakeFs::new(cx.executor());
1205        fs.insert_tree(path!("/test"), files).await;
1206        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1207
1208        (fs, project)
1209    }
1210
1211    struct FakeContextServerDescriptor {
1212        path: String,
1213    }
1214
1215    impl FakeContextServerDescriptor {
1216        fn new(path: impl Into<String>) -> Self {
1217            Self { path: path.into() }
1218        }
1219    }
1220
1221    impl ContextServerDescriptor for FakeContextServerDescriptor {
1222        fn command(
1223            &self,
1224            _worktree_store: Entity<WorktreeStore>,
1225            _cx: &AsyncApp,
1226        ) -> Task<Result<ContextServerCommand>> {
1227            Task::ready(Ok(ContextServerCommand {
1228                path: self.path.clone(),
1229                args: vec!["arg1".to_string(), "arg2".to_string()],
1230                env: None,
1231            }))
1232        }
1233
1234        fn configuration(
1235            &self,
1236            _worktree_store: Entity<WorktreeStore>,
1237            _cx: &AsyncApp,
1238        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1239            Task::ready(Ok(None))
1240        }
1241    }
1242}