context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    project_settings::{ContextServerSettings, ProjectSettings},
  17    worktree_store::WorktreeStore,
  18};
  19
  20pub fn init(cx: &mut App) {
  21    extension::init(cx);
  22}
  23
// Declares the actions of the `context_server` namespace via gpui's `actions!`
// macro. NOTE(review): the handler for `Restart` is not in this file section.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  31
/// Public lifecycle status of a context server. It is derived from the store's
/// internal state and broadcast through [`Event::ServerStatusChanged`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// A start was requested and is still in progress.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or was removed from the store.
    Stopped,
    /// The server failed to start; carries the error message.
    Error(Arc<str>),
}
  39
  40impl ContextServerStatus {
  41    fn from_state(state: &ContextServerState) -> Self {
  42        match state {
  43            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  44            ContextServerState::Running { .. } => ContextServerStatus::Running,
  45            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  46            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  47        }
  48    }
  49}
  50
/// Internal lifecycle state of a single context server.
///
/// Every variant keeps the server handle and the configuration the server was
/// (or is being) started with.
enum ContextServerState {
    /// A start was requested, and `_task` is awaiting `ContextServer::start`.
    /// When it completes, the task replaces this state with `Running` or
    /// `Error`. Because the task is owned here, replacing or removing this
    /// state drops it.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped via `ContextServerStore::stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Starting the server failed; `error` holds the error's message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
  71
  72impl ContextServerState {
  73    pub fn server(&self) -> Arc<ContextServer> {
  74        match self {
  75            ContextServerState::Starting { server, .. } => server.clone(),
  76            ContextServerState::Running { server, .. } => server.clone(),
  77            ContextServerState::Stopped { server, .. } => server.clone(),
  78            ContextServerState::Error { server, .. } => server.clone(),
  79        }
  80    }
  81
  82    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  83        match self {
  84            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  85            ContextServerState::Running { configuration, .. } => configuration.clone(),
  86            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  87            ContextServerState::Error { configuration, .. } => configuration.clone(),
  88        }
  89    }
  90}
  91
/// Resolved launch configuration for a context server.
///
/// The store compares it with the configuration of an existing server to
/// decide whether that server must be restarted.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// Built from `ContextServerSettings::Custom`; uses the configured command
    /// as-is.
    Custom {
        command: ContextServerCommand,
    },
    /// Built from `ContextServerSettings::Extension`. The command comes from the
    /// registered descriptor, and `settings` are the extension settings taken
    /// from the user's configuration.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
 102
 103impl ContextServerConfiguration {
 104    pub fn command(&self) -> &ContextServerCommand {
 105        match self {
 106            ContextServerConfiguration::Custom { command } => command,
 107            ContextServerConfiguration::Extension { command, .. } => command,
 108        }
 109    }
 110
 111    pub async fn from_settings(
 112        settings: ContextServerSettings,
 113        id: ContextServerId,
 114        registry: Entity<ContextServerDescriptorRegistry>,
 115        worktree_store: Entity<WorktreeStore>,
 116        cx: &AsyncApp,
 117    ) -> Option<Self> {
 118        match settings {
 119            ContextServerSettings::Custom {
 120                enabled: _,
 121                command,
 122            } => Some(ContextServerConfiguration::Custom { command }),
 123            ContextServerSettings::Extension {
 124                enabled: _,
 125                settings,
 126            } => {
 127                let descriptor = cx
 128                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 129                    .ok()
 130                    .flatten()?;
 131
 132                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 133
 134                Some(ContextServerConfiguration::Extension { command, settings })
 135            }
 136        }
 137    }
 138}
 139
/// Constructor the store uses to create server instances instead of the
/// default `ContextServer::stdio` construction (see `create_context_server`).
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 142
/// Tracks the configured context servers and the lifecycle state of each one.
///
/// When the server loop is enabled, the store observes the settings and the
/// descriptor registry. It reconciles the set of servers whenever either
/// changes.
pub struct ContextServerStore {
    /// Context server settings keyed by server id. They are resolved for the
    /// first visible worktree, or with no location when there is none.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every server known to the store.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Used to pick the settings location and to resolve extension commands.
    worktree_store: Entity<WorktreeStore>,
    /// Descriptors of extension-provided context servers.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The in-flight reconciliation task, if any; cleared when it finishes.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for constructing servers (tests use this).
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a change arrives while a reconciliation is already running, so
    /// that another pass is scheduled once the current one completes.
    needs_server_update: bool,
    /// Keeps the registry/settings observers alive for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
 153
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// Emitted whenever a server's state is recorded in the store. It is also
    /// emitted, with status `Stopped`, when a server is removed from the store.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
 162
 163impl ContextServerStore {
 164    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 165        Self::new_internal(
 166            true,
 167            None,
 168            ContextServerDescriptorRegistry::default_global(cx),
 169            worktree_store,
 170            cx,
 171        )
 172    }
 173
 174    /// Returns all configured context server ids, regardless of enabled state.
 175    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 176        self.context_server_settings
 177            .keys()
 178            .cloned()
 179            .map(ContextServerId)
 180            .collect()
 181    }
 182
 183    #[cfg(any(test, feature = "test-support"))]
 184    pub fn test(
 185        registry: Entity<ContextServerDescriptorRegistry>,
 186        worktree_store: Entity<WorktreeStore>,
 187        cx: &mut Context<Self>,
 188    ) -> Self {
 189        Self::new_internal(false, None, registry, worktree_store, cx)
 190    }
 191
 192    #[cfg(any(test, feature = "test-support"))]
 193    pub fn test_maintain_server_loop(
 194        context_server_factory: ContextServerFactory,
 195        registry: Entity<ContextServerDescriptorRegistry>,
 196        worktree_store: Entity<WorktreeStore>,
 197        cx: &mut Context<Self>,
 198    ) -> Self {
 199        Self::new_internal(
 200            true,
 201            Some(context_server_factory),
 202            registry,
 203            worktree_store,
 204            cx,
 205        )
 206    }
 207
 208    fn new_internal(
 209        maintain_server_loop: bool,
 210        context_server_factory: Option<ContextServerFactory>,
 211        registry: Entity<ContextServerDescriptorRegistry>,
 212        worktree_store: Entity<WorktreeStore>,
 213        cx: &mut Context<Self>,
 214    ) -> Self {
 215        let subscriptions = if maintain_server_loop {
 216            vec![
 217                cx.observe(&registry, |this, _registry, cx| {
 218                    this.available_context_servers_changed(cx);
 219                }),
 220                cx.observe_global::<SettingsStore>(|this, cx| {
 221                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 222                    if &this.context_server_settings == settings {
 223                        return;
 224                    }
 225                    this.context_server_settings = settings.clone();
 226                    this.available_context_servers_changed(cx);
 227                }),
 228            ]
 229        } else {
 230            Vec::new()
 231        };
 232
 233        let mut this = Self {
 234            _subscriptions: subscriptions,
 235            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 236                .clone(),
 237            worktree_store,
 238            registry,
 239            needs_server_update: false,
 240            servers: HashMap::default(),
 241            update_servers_task: None,
 242            context_server_factory,
 243        };
 244        if maintain_server_loop {
 245            this.available_context_servers_changed(cx);
 246        }
 247        this
 248    }
 249
 250    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 251        self.servers.get(id).map(|state| state.server())
 252    }
 253
 254    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 255        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 256            Some(server.clone())
 257        } else {
 258            None
 259        }
 260    }
 261
 262    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 263        self.servers.get(id).map(ContextServerStatus::from_state)
 264    }
 265
 266    pub fn configuration_for_server(
 267        &self,
 268        id: &ContextServerId,
 269    ) -> Option<Arc<ContextServerConfiguration>> {
 270        self.servers.get(id).map(|state| state.configuration())
 271    }
 272
 273    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 274        self.servers.keys().cloned().collect()
 275    }
 276
 277    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 278        self.servers
 279            .values()
 280            .filter_map(|state| {
 281                if let ContextServerState::Running { server, .. } = state {
 282                    Some(server.clone())
 283                } else {
 284                    None
 285                }
 286            })
 287            .collect()
 288    }
 289
 290    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 291        cx.spawn(async move |this, cx| {
 292            let this = this.upgrade().context("Context server store dropped")?;
 293            let settings = this
 294                .update(cx, |this, _| {
 295                    this.context_server_settings.get(&server.id().0).cloned()
 296                })
 297                .ok()
 298                .flatten()
 299                .context("Failed to get context server settings")?;
 300
 301            if !settings.enabled() {
 302                return Ok(());
 303            }
 304
 305            let (registry, worktree_store) = this.update(cx, |this, _| {
 306                (this.registry.clone(), this.worktree_store.clone())
 307            })?;
 308            let configuration = ContextServerConfiguration::from_settings(
 309                settings,
 310                server.id(),
 311                registry,
 312                worktree_store,
 313                cx,
 314            )
 315            .await
 316            .context("Failed to create context server configuration")?;
 317
 318            this.update(cx, |this, cx| {
 319                this.run_server(server, Arc::new(configuration), cx)
 320            })
 321        })
 322        .detach_and_log_err(cx);
 323    }
 324
 325    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 326        if matches!(
 327            self.servers.get(id),
 328            Some(ContextServerState::Stopped { .. })
 329        ) {
 330            return Ok(());
 331        }
 332
 333        let state = self
 334            .servers
 335            .remove(id)
 336            .context("Context server not found")?;
 337
 338        let server = state.server();
 339        let configuration = state.configuration();
 340        let mut result = Ok(());
 341        if let ContextServerState::Running { server, .. } = &state {
 342            result = server.stop();
 343        }
 344        drop(state);
 345
 346        self.update_server_state(
 347            id.clone(),
 348            ContextServerState::Stopped {
 349                configuration,
 350                server,
 351            },
 352            cx,
 353        );
 354
 355        result
 356    }
 357
 358    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 359        if let Some(state) = self.servers.get(&id) {
 360            let configuration = state.configuration();
 361
 362            self.stop_server(&state.server().id(), cx)?;
 363            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 364            self.run_server(new_server, configuration, cx);
 365        }
 366        Ok(())
 367    }
 368
 369    fn run_server(
 370        &mut self,
 371        server: Arc<ContextServer>,
 372        configuration: Arc<ContextServerConfiguration>,
 373        cx: &mut Context<Self>,
 374    ) {
 375        let id = server.id();
 376        if matches!(
 377            self.servers.get(&id),
 378            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 379        ) {
 380            self.stop_server(&id, cx).log_err();
 381        }
 382
 383        let task = cx.spawn({
 384            let id = server.id();
 385            let server = server.clone();
 386            let configuration = configuration.clone();
 387            async move |this, cx| {
 388                match server.clone().start(&cx).await {
 389                    Ok(_) => {
 390                        log::info!("Started {} context server", id);
 391                        debug_assert!(server.client().is_some());
 392
 393                        this.update(cx, |this, cx| {
 394                            this.update_server_state(
 395                                id.clone(),
 396                                ContextServerState::Running {
 397                                    server,
 398                                    configuration,
 399                                },
 400                                cx,
 401                            )
 402                        })
 403                        .log_err()
 404                    }
 405                    Err(err) => {
 406                        log::error!("{} context server failed to start: {}", id, err);
 407                        this.update(cx, |this, cx| {
 408                            this.update_server_state(
 409                                id.clone(),
 410                                ContextServerState::Error {
 411                                    configuration,
 412                                    server,
 413                                    error: err.to_string().into(),
 414                                },
 415                                cx,
 416                            )
 417                        })
 418                        .log_err()
 419                    }
 420                };
 421            }
 422        });
 423
 424        self.update_server_state(
 425            id.clone(),
 426            ContextServerState::Starting {
 427                configuration,
 428                _task: task,
 429                server,
 430            },
 431            cx,
 432        );
 433    }
 434
 435    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 436        let state = self
 437            .servers
 438            .remove(id)
 439            .context("Context server not found")?;
 440        drop(state);
 441        cx.emit(Event::ServerStatusChanged {
 442            server_id: id.clone(),
 443            status: ContextServerStatus::Stopped,
 444        });
 445        Ok(())
 446    }
 447
 448    fn create_context_server(
 449        &self,
 450        id: ContextServerId,
 451        configuration: Arc<ContextServerConfiguration>,
 452    ) -> Result<Arc<ContextServer>> {
 453        if let Some(factory) = self.context_server_factory.as_ref() {
 454            Ok(factory(id, configuration))
 455        } else {
 456            Ok(Arc::new(ContextServer::stdio(
 457                id,
 458                configuration.command().clone(),
 459            )))
 460        }
 461    }
 462
 463    fn resolve_context_server_settings<'a>(
 464        worktree_store: &'a Entity<WorktreeStore>,
 465        cx: &'a App,
 466    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 467        let location = worktree_store
 468            .read(cx)
 469            .visible_worktrees(cx)
 470            .next()
 471            .map(|worktree| settings::SettingsLocation {
 472                worktree_id: worktree.read(cx).id(),
 473                path: Path::new(""),
 474            });
 475        &ProjectSettings::get(location, cx).context_servers
 476    }
 477
 478    fn update_server_state(
 479        &mut self,
 480        id: ContextServerId,
 481        state: ContextServerState,
 482        cx: &mut Context<Self>,
 483    ) {
 484        let status = ContextServerStatus::from_state(&state);
 485        self.servers.insert(id.clone(), state);
 486        cx.emit(Event::ServerStatusChanged {
 487            server_id: id,
 488            status,
 489        });
 490    }
 491
 492    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 493        if self.update_servers_task.is_some() {
 494            self.needs_server_update = true;
 495        } else {
 496            self.needs_server_update = false;
 497            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 498                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 499                    log::error!("Error maintaining context servers: {}", err);
 500                }
 501
 502                this.update(cx, |this, cx| {
 503                    this.update_servers_task.take();
 504                    if this.needs_server_update {
 505                        this.available_context_servers_changed(cx);
 506                    }
 507                })?;
 508
 509                Ok(())
 510            }));
 511        }
 512    }
 513
 514    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 515        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 516            (
 517                this.context_server_settings.clone(),
 518                this.registry.clone(),
 519                this.worktree_store.clone(),
 520            )
 521        })?;
 522
 523        for (id, _) in
 524            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 525        {
 526            configured_servers
 527                .entry(id)
 528                .or_insert(ContextServerSettings::default_extension());
 529        }
 530
 531        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 532            configured_servers
 533                .into_iter()
 534                .partition(|(_, settings)| settings.enabled());
 535
 536        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 537            let id = ContextServerId(id);
 538            ContextServerConfiguration::from_settings(
 539                settings,
 540                id.clone(),
 541                registry.clone(),
 542                worktree_store.clone(),
 543                cx,
 544            )
 545            .map(|config| (id, config))
 546        }))
 547        .await
 548        .into_iter()
 549        .filter_map(|(id, config)| config.map(|config| (id, config)))
 550        .collect::<HashMap<_, _>>();
 551
 552        let mut servers_to_start = Vec::new();
 553        let mut servers_to_remove = HashSet::default();
 554        let mut servers_to_stop = HashSet::default();
 555
 556        this.update(cx, |this, _cx| {
 557            for server_id in this.servers.keys() {
 558                // All servers that are not in desired_servers should be removed from the store.
 559                // This can happen if the user removed a server from the context server settings.
 560                if !configured_servers.contains_key(&server_id) {
 561                    if disabled_servers.contains_key(&server_id.0) {
 562                        servers_to_stop.insert(server_id.clone());
 563                    } else {
 564                        servers_to_remove.insert(server_id.clone());
 565                    }
 566                }
 567            }
 568
 569            for (id, config) in configured_servers {
 570                let state = this.servers.get(&id);
 571                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 572                let existing_config = state.as_ref().map(|state| state.configuration());
 573                if existing_config.as_deref() != Some(&config) || is_stopped {
 574                    let config = Arc::new(config);
 575                    if let Some(server) = this
 576                        .create_context_server(id.clone(), config.clone())
 577                        .log_err()
 578                    {
 579                        servers_to_start.push((server, config));
 580                        if this.servers.contains_key(&id) {
 581                            servers_to_stop.insert(id);
 582                        }
 583                    }
 584                }
 585            }
 586        })?;
 587
 588        this.update(cx, |this, cx| {
 589            for id in servers_to_stop {
 590                this.stop_server(&id, cx)?;
 591            }
 592            for id in servers_to_remove {
 593                this.remove_server(&id, cx)?;
 594            }
 595            for (server, config) in servers_to_start {
 596                this.run_server(server, config, cx);
 597            }
 598            anyhow::Ok(())
 599        })?
 600    }
 601}
 602
 603#[cfg(test)]
 604mod tests {
 605    use super::*;
 606    use crate::{
 607        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 608        project_settings::ProjectSettings,
 609    };
 610    use context_server::test::create_fake_transport;
 611    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 612    use serde_json::json;
 613    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 614    use util::path;
 615
 616    #[gpui::test]
 617    async fn test_context_server_status(cx: &mut TestAppContext) {
 618        const SERVER_1_ID: &'static str = "mcp-1";
 619        const SERVER_2_ID: &'static str = "mcp-2";
 620
 621        let (_fs, project) = setup_context_server_test(
 622            cx,
 623            json!({"code.rs": ""}),
 624            vec![
 625                (SERVER_1_ID.into(), dummy_server_settings()),
 626                (SERVER_2_ID.into(), dummy_server_settings()),
 627            ],
 628        )
 629        .await;
 630
 631        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 632        let store = cx.new(|cx| {
 633            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 634        });
 635
 636        let server_1_id = ContextServerId(SERVER_1_ID.into());
 637        let server_2_id = ContextServerId(SERVER_2_ID.into());
 638
 639        let server_1 = Arc::new(ContextServer::new(
 640            server_1_id.clone(),
 641            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 642        ));
 643        let server_2 = Arc::new(ContextServer::new(
 644            server_2_id.clone(),
 645            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 646        ));
 647
 648        store.update(cx, |store, cx| store.start_server(server_1, cx));
 649
 650        cx.run_until_parked();
 651
 652        cx.update(|cx| {
 653            assert_eq!(
 654                store.read(cx).status_for_server(&server_1_id),
 655                Some(ContextServerStatus::Running)
 656            );
 657            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 658        });
 659
 660        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 661
 662        cx.run_until_parked();
 663
 664        cx.update(|cx| {
 665            assert_eq!(
 666                store.read(cx).status_for_server(&server_1_id),
 667                Some(ContextServerStatus::Running)
 668            );
 669            assert_eq!(
 670                store.read(cx).status_for_server(&server_2_id),
 671                Some(ContextServerStatus::Running)
 672            );
 673        });
 674
 675        store
 676            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 677            .unwrap();
 678
 679        cx.update(|cx| {
 680            assert_eq!(
 681                store.read(cx).status_for_server(&server_1_id),
 682                Some(ContextServerStatus::Running)
 683            );
 684            assert_eq!(
 685                store.read(cx).status_for_server(&server_2_id),
 686                Some(ContextServerStatus::Stopped)
 687            );
 688        });
 689    }
 690
 691    #[gpui::test]
 692    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 693        const SERVER_1_ID: &'static str = "mcp-1";
 694        const SERVER_2_ID: &'static str = "mcp-2";
 695
 696        let (_fs, project) = setup_context_server_test(
 697            cx,
 698            json!({"code.rs": ""}),
 699            vec![
 700                (SERVER_1_ID.into(), dummy_server_settings()),
 701                (SERVER_2_ID.into(), dummy_server_settings()),
 702            ],
 703        )
 704        .await;
 705
 706        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 707        let store = cx.new(|cx| {
 708            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 709        });
 710
 711        let server_1_id = ContextServerId(SERVER_1_ID.into());
 712        let server_2_id = ContextServerId(SERVER_2_ID.into());
 713
 714        let server_1 = Arc::new(ContextServer::new(
 715            server_1_id.clone(),
 716            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 717        ));
 718        let server_2 = Arc::new(ContextServer::new(
 719            server_2_id.clone(),
 720            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 721        ));
 722
 723        let _server_events = assert_server_events(
 724            &store,
 725            vec![
 726                (server_1_id.clone(), ContextServerStatus::Starting),
 727                (server_1_id.clone(), ContextServerStatus::Running),
 728                (server_2_id.clone(), ContextServerStatus::Starting),
 729                (server_2_id.clone(), ContextServerStatus::Running),
 730                (server_2_id.clone(), ContextServerStatus::Stopped),
 731            ],
 732            cx,
 733        );
 734
 735        store.update(cx, |store, cx| store.start_server(server_1, cx));
 736
 737        cx.run_until_parked();
 738
 739        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 740
 741        cx.run_until_parked();
 742
 743        store
 744            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 745            .unwrap();
 746    }
 747
 748    #[gpui::test(iterations = 25)]
 749    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 750        const SERVER_1_ID: &'static str = "mcp-1";
 751
 752        let (_fs, project) = setup_context_server_test(
 753            cx,
 754            json!({"code.rs": ""}),
 755            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 756        )
 757        .await;
 758
 759        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 760        let store = cx.new(|cx| {
 761            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 762        });
 763
 764        let server_id = ContextServerId(SERVER_1_ID.into());
 765
 766        let server_with_same_id_1 = Arc::new(ContextServer::new(
 767            server_id.clone(),
 768            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 769        ));
 770        let server_with_same_id_2 = Arc::new(ContextServer::new(
 771            server_id.clone(),
 772            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 773        ));
 774
 775        // If we start another server with the same id, we should report that we stopped the previous one
 776        let _server_events = assert_server_events(
 777            &store,
 778            vec![
 779                (server_id.clone(), ContextServerStatus::Starting),
 780                (server_id.clone(), ContextServerStatus::Stopped),
 781                (server_id.clone(), ContextServerStatus::Starting),
 782                (server_id.clone(), ContextServerStatus::Running),
 783            ],
 784            cx,
 785        );
 786
 787        store.update(cx, |store, cx| {
 788            store.start_server(server_with_same_id_1.clone(), cx)
 789        });
 790        store.update(cx, |store, cx| {
 791            store.start_server(server_with_same_id_2.clone(), cx)
 792        });
 793
 794        cx.run_until_parked();
 795
 796        cx.update(|cx| {
 797            assert_eq!(
 798                store.read(cx).status_for_server(&server_id),
 799                Some(ContextServerStatus::Running)
 800            );
 801        });
 802    }
 803
 804    #[gpui::test]
 805    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 806        const SERVER_1_ID: &'static str = "mcp-1";
 807        const SERVER_2_ID: &'static str = "mcp-2";
 808
 809        let server_1_id = ContextServerId(SERVER_1_ID.into());
 810        let server_2_id = ContextServerId(SERVER_2_ID.into());
 811
 812        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 813
 814        let (_fs, project) = setup_context_server_test(
 815            cx,
 816            json!({"code.rs": ""}),
 817            vec![(
 818                SERVER_1_ID.into(),
 819                ContextServerSettings::Extension {
 820                    enabled: true,
 821                    settings: json!({
 822                        "somevalue": true
 823                    }),
 824                },
 825            )],
 826        )
 827        .await;
 828
 829        let executor = cx.executor();
 830        let registry = cx.new(|cx| {
 831            let mut registry = ContextServerDescriptorRegistry::new();
 832            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 833            registry
 834        });
 835        let store = cx.new(|cx| {
 836            ContextServerStore::test_maintain_server_loop(
 837                Box::new(move |id, _| {
 838                    Arc::new(ContextServer::new(
 839                        id.clone(),
 840                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 841                    ))
 842                }),
 843                registry.clone(),
 844                project.read(cx).worktree_store(),
 845                cx,
 846            )
 847        });
 848
 849        // Ensure that mcp-1 starts up
 850        {
 851            let _server_events = assert_server_events(
 852                &store,
 853                vec![
 854                    (server_1_id.clone(), ContextServerStatus::Starting),
 855                    (server_1_id.clone(), ContextServerStatus::Running),
 856                ],
 857                cx,
 858            );
 859            cx.run_until_parked();
 860        }
 861
 862        // Ensure that mcp-1 is restarted when the configuration was changed
 863        {
 864            let _server_events = assert_server_events(
 865                &store,
 866                vec![
 867                    (server_1_id.clone(), ContextServerStatus::Stopped),
 868                    (server_1_id.clone(), ContextServerStatus::Starting),
 869                    (server_1_id.clone(), ContextServerStatus::Running),
 870                ],
 871                cx,
 872            );
 873            set_context_server_configuration(
 874                vec![(
 875                    server_1_id.0.clone(),
 876                    ContextServerSettings::Extension {
 877                        enabled: true,
 878                        settings: json!({
 879                            "somevalue": false
 880                        }),
 881                    },
 882                )],
 883                cx,
 884            );
 885
 886            cx.run_until_parked();
 887        }
 888
 889        // Ensure that mcp-1 is not restarted when the configuration was not changed
 890        {
 891            let _server_events = assert_server_events(&store, vec![], cx);
 892            set_context_server_configuration(
 893                vec![(
 894                    server_1_id.0.clone(),
 895                    ContextServerSettings::Extension {
 896                        enabled: true,
 897                        settings: json!({
 898                            "somevalue": false
 899                        }),
 900                    },
 901                )],
 902                cx,
 903            );
 904
 905            cx.run_until_parked();
 906        }
 907
 908        // Ensure that mcp-2 is started once it is added to the settings
 909        {
 910            let _server_events = assert_server_events(
 911                &store,
 912                vec![
 913                    (server_2_id.clone(), ContextServerStatus::Starting),
 914                    (server_2_id.clone(), ContextServerStatus::Running),
 915                ],
 916                cx,
 917            );
 918            set_context_server_configuration(
 919                vec![
 920                    (
 921                        server_1_id.0.clone(),
 922                        ContextServerSettings::Extension {
 923                            enabled: true,
 924                            settings: json!({
 925                                "somevalue": false
 926                            }),
 927                        },
 928                    ),
 929                    (
 930                        server_2_id.0.clone(),
 931                        ContextServerSettings::Custom {
 932                            enabled: true,
 933                            command: ContextServerCommand {
 934                                path: "somebinary".into(),
 935                                args: vec!["arg".to_string()],
 936                                env: None,
 937                            },
 938                        },
 939                    ),
 940                ],
 941                cx,
 942            );
 943
 944            cx.run_until_parked();
 945        }
 946
 947        // Ensure that mcp-2 is restarted once the args have changed
 948        {
 949            let _server_events = assert_server_events(
 950                &store,
 951                vec![
 952                    (server_2_id.clone(), ContextServerStatus::Stopped),
 953                    (server_2_id.clone(), ContextServerStatus::Starting),
 954                    (server_2_id.clone(), ContextServerStatus::Running),
 955                ],
 956                cx,
 957            );
 958            set_context_server_configuration(
 959                vec![
 960                    (
 961                        server_1_id.0.clone(),
 962                        ContextServerSettings::Extension {
 963                            enabled: true,
 964                            settings: json!({
 965                                "somevalue": false
 966                            }),
 967                        },
 968                    ),
 969                    (
 970                        server_2_id.0.clone(),
 971                        ContextServerSettings::Custom {
 972                            enabled: true,
 973                            command: ContextServerCommand {
 974                                path: "somebinary".into(),
 975                                args: vec!["anotherArg".to_string()],
 976                                env: None,
 977                            },
 978                        },
 979                    ),
 980                ],
 981                cx,
 982            );
 983
 984            cx.run_until_parked();
 985        }
 986
 987        // Ensure that mcp-2 is removed once it is removed from the settings
 988        {
 989            let _server_events = assert_server_events(
 990                &store,
 991                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 992                cx,
 993            );
 994            set_context_server_configuration(
 995                vec![(
 996                    server_1_id.0.clone(),
 997                    ContextServerSettings::Extension {
 998                        enabled: true,
 999                        settings: json!({
1000                            "somevalue": false
1001                        }),
1002                    },
1003                )],
1004                cx,
1005            );
1006
1007            cx.run_until_parked();
1008
1009            cx.update(|cx| {
1010                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1011            });
1012        }
1013
1014        // Ensure that nothing happens if the settings do not change
1015        {
1016            let _server_events = assert_server_events(&store, vec![], cx);
1017            set_context_server_configuration(
1018                vec![(
1019                    server_1_id.0.clone(),
1020                    ContextServerSettings::Extension {
1021                        enabled: true,
1022                        settings: json!({
1023                            "somevalue": false
1024                        }),
1025                    },
1026                )],
1027                cx,
1028            );
1029
1030            cx.run_until_parked();
1031
1032            cx.update(|cx| {
1033                assert_eq!(
1034                    store.read(cx).status_for_server(&server_1_id),
1035                    Some(ContextServerStatus::Running)
1036                );
1037                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1038            });
1039        }
1040    }
1041
1042    #[gpui::test]
1043    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1044        const SERVER_1_ID: &'static str = "mcp-1";
1045
1046        let server_1_id = ContextServerId(SERVER_1_ID.into());
1047
1048        let (_fs, project) = setup_context_server_test(
1049            cx,
1050            json!({"code.rs": ""}),
1051            vec![(
1052                SERVER_1_ID.into(),
1053                ContextServerSettings::Custom {
1054                    enabled: true,
1055                    command: ContextServerCommand {
1056                        path: "somebinary".into(),
1057                        args: vec!["arg".to_string()],
1058                        env: None,
1059                    },
1060                },
1061            )],
1062        )
1063        .await;
1064
1065        let executor = cx.executor();
1066        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1067        let store = cx.new(|cx| {
1068            ContextServerStore::test_maintain_server_loop(
1069                Box::new(move |id, _| {
1070                    Arc::new(ContextServer::new(
1071                        id.clone(),
1072                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1073                    ))
1074                }),
1075                registry.clone(),
1076                project.read(cx).worktree_store(),
1077                cx,
1078            )
1079        });
1080
1081        // Ensure that mcp-1 starts up
1082        {
1083            let _server_events = assert_server_events(
1084                &store,
1085                vec![
1086                    (server_1_id.clone(), ContextServerStatus::Starting),
1087                    (server_1_id.clone(), ContextServerStatus::Running),
1088                ],
1089                cx,
1090            );
1091            cx.run_until_parked();
1092        }
1093
1094        // Ensure that mcp-1 is stopped once it is disabled.
1095        {
1096            let _server_events = assert_server_events(
1097                &store,
1098                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1099                cx,
1100            );
1101            set_context_server_configuration(
1102                vec![(
1103                    server_1_id.0.clone(),
1104                    ContextServerSettings::Custom {
1105                        enabled: false,
1106                        command: ContextServerCommand {
1107                            path: "somebinary".into(),
1108                            args: vec!["arg".to_string()],
1109                            env: None,
1110                        },
1111                    },
1112                )],
1113                cx,
1114            );
1115
1116            cx.run_until_parked();
1117        }
1118
1119        // Ensure that mcp-1 is started once it is enabled again.
1120        {
1121            let _server_events = assert_server_events(
1122                &store,
1123                vec![
1124                    (server_1_id.clone(), ContextServerStatus::Starting),
1125                    (server_1_id.clone(), ContextServerStatus::Running),
1126                ],
1127                cx,
1128            );
1129            set_context_server_configuration(
1130                vec![(
1131                    server_1_id.0.clone(),
1132                    ContextServerSettings::Custom {
1133                        enabled: true,
1134                        command: ContextServerCommand {
1135                            path: "somebinary".into(),
1136                            args: vec!["arg".to_string()],
1137                            env: None,
1138                        },
1139                    },
1140                )],
1141                cx,
1142            );
1143
1144            cx.run_until_parked();
1145        }
1146    }
1147
1148    fn set_context_server_configuration(
1149        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1150        cx: &mut TestAppContext,
1151    ) {
1152        cx.update(|cx| {
1153            SettingsStore::update_global(cx, |store, cx| {
1154                let mut settings = ProjectSettings::default();
1155                for (id, config) in context_servers {
1156                    settings.context_servers.insert(id, config);
1157                }
1158                store
1159                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1160                    .unwrap();
1161            })
1162        });
1163    }
1164
1165    struct ServerEvents {
1166        received_event_count: Rc<RefCell<usize>>,
1167        expected_event_count: usize,
1168        _subscription: Subscription,
1169    }
1170
1171    impl Drop for ServerEvents {
1172        fn drop(&mut self) {
1173            let actual_event_count = *self.received_event_count.borrow();
1174            assert_eq!(
1175                actual_event_count, self.expected_event_count,
1176                "
1177                Expected to receive {} context server store events, but received {} events",
1178                self.expected_event_count, actual_event_count
1179            );
1180        }
1181    }
1182
1183    fn dummy_server_settings() -> ContextServerSettings {
1184        ContextServerSettings::Custom {
1185            enabled: true,
1186            command: ContextServerCommand {
1187                path: "somebinary".into(),
1188                args: vec!["arg".to_string()],
1189                env: None,
1190            },
1191        }
1192    }
1193
1194    fn assert_server_events(
1195        store: &Entity<ContextServerStore>,
1196        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1197        cx: &mut TestAppContext,
1198    ) -> ServerEvents {
1199        cx.update(|cx| {
1200            let mut ix = 0;
1201            let received_event_count = Rc::new(RefCell::new(0));
1202            let expected_event_count = expected_events.len();
1203            let subscription = cx.subscribe(store, {
1204                let received_event_count = received_event_count.clone();
1205                move |_, event, _| match event {
1206                    Event::ServerStatusChanged {
1207                        server_id: actual_server_id,
1208                        status: actual_status,
1209                    } => {
1210                        let (expected_server_id, expected_status) = &expected_events[ix];
1211
1212                        assert_eq!(
1213                            actual_server_id, expected_server_id,
1214                            "Expected different server id at index {}",
1215                            ix
1216                        );
1217                        assert_eq!(
1218                            actual_status, expected_status,
1219                            "Expected different status at index {}",
1220                            ix
1221                        );
1222                        ix += 1;
1223                        *received_event_count.borrow_mut() += 1;
1224                    }
1225                }
1226            });
1227            ServerEvents {
1228                expected_event_count,
1229                received_event_count,
1230                _subscription: subscription,
1231            }
1232        })
1233    }
1234
1235    async fn setup_context_server_test(
1236        cx: &mut TestAppContext,
1237        files: serde_json::Value,
1238        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1239    ) -> (Arc<FakeFs>, Entity<Project>) {
1240        cx.update(|cx| {
1241            let settings_store = SettingsStore::test(cx);
1242            cx.set_global(settings_store);
1243            Project::init_settings(cx);
1244            let mut settings = ProjectSettings::get_global(cx).clone();
1245            for (id, config) in context_server_configurations {
1246                settings.context_servers.insert(id, config);
1247            }
1248            ProjectSettings::override_global(settings, cx);
1249        });
1250
1251        let fs = FakeFs::new(cx.executor());
1252        fs.insert_tree(path!("/test"), files).await;
1253        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1254
1255        (fs, project)
1256    }
1257
1258    struct FakeContextServerDescriptor {
1259        path: PathBuf,
1260    }
1261
1262    impl FakeContextServerDescriptor {
1263        fn new(path: impl Into<PathBuf>) -> Self {
1264            Self { path: path.into() }
1265        }
1266    }
1267
1268    impl ContextServerDescriptor for FakeContextServerDescriptor {
1269        fn command(
1270            &self,
1271            _worktree_store: Entity<WorktreeStore>,
1272            _cx: &AsyncApp,
1273        ) -> Task<Result<ContextServerCommand>> {
1274            Task::ready(Ok(ContextServerCommand {
1275                path: self.path.clone(),
1276                args: vec!["arg1".to_string(), "arg2".to_string()],
1277                env: None,
1278            }))
1279        }
1280
1281        fn configuration(
1282            &self,
1283            _worktree_store: Entity<WorktreeStore>,
1284            _cx: &AsyncApp,
1285        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1286            Task::ready(Ok(None))
1287        }
1288    }
1289}