context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::{ResultExt as _, rel_path::RelPath};
  14
  15use crate::{
  16    Project,
  17    project_settings::{ContextServerSettings, ProjectSettings},
  18    worktree_store::WorktreeStore,
  19};
  20
  21pub fn init(cx: &mut App) {
  22    extension::init(cx);
  23}
  24
  25actions!(
  26    context_server,
  27    [
  28        /// Restarts the context server.
  29        Restart
  30    ]
  31);
  32
  33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  34pub enum ContextServerStatus {
  35    Starting,
  36    Running,
  37    Stopped,
  38    Error(Arc<str>),
  39}
  40
  41impl ContextServerStatus {
  42    fn from_state(state: &ContextServerState) -> Self {
  43        match state {
  44            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  45            ContextServerState::Running { .. } => ContextServerStatus::Running,
  46            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  47            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  48        }
  49    }
  50}
  51
  52enum ContextServerState {
  53    Starting {
  54        server: Arc<ContextServer>,
  55        configuration: Arc<ContextServerConfiguration>,
  56        _task: Task<()>,
  57    },
  58    Running {
  59        server: Arc<ContextServer>,
  60        configuration: Arc<ContextServerConfiguration>,
  61    },
  62    Stopped {
  63        server: Arc<ContextServer>,
  64        configuration: Arc<ContextServerConfiguration>,
  65    },
  66    Error {
  67        server: Arc<ContextServer>,
  68        configuration: Arc<ContextServerConfiguration>,
  69        error: Arc<str>,
  70    },
  71}
  72
  73impl ContextServerState {
  74    pub fn server(&self) -> Arc<ContextServer> {
  75        match self {
  76            ContextServerState::Starting { server, .. } => server.clone(),
  77            ContextServerState::Running { server, .. } => server.clone(),
  78            ContextServerState::Stopped { server, .. } => server.clone(),
  79            ContextServerState::Error { server, .. } => server.clone(),
  80        }
  81    }
  82
  83    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  84        match self {
  85            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  86            ContextServerState::Running { configuration, .. } => configuration.clone(),
  87            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  88            ContextServerState::Error { configuration, .. } => configuration.clone(),
  89        }
  90    }
  91}
  92
  93#[derive(Debug, PartialEq, Eq)]
  94pub enum ContextServerConfiguration {
  95    Custom {
  96        command: ContextServerCommand,
  97    },
  98    Extension {
  99        command: ContextServerCommand,
 100        settings: serde_json::Value,
 101    },
 102}
 103
 104impl ContextServerConfiguration {
 105    pub fn command(&self) -> &ContextServerCommand {
 106        match self {
 107            ContextServerConfiguration::Custom { command } => command,
 108            ContextServerConfiguration::Extension { command, .. } => command,
 109        }
 110    }
 111
 112    pub async fn from_settings(
 113        settings: ContextServerSettings,
 114        id: ContextServerId,
 115        registry: Entity<ContextServerDescriptorRegistry>,
 116        worktree_store: Entity<WorktreeStore>,
 117        cx: &AsyncApp,
 118    ) -> Option<Self> {
 119        match settings {
 120            ContextServerSettings::Custom {
 121                enabled: _,
 122                command,
 123            } => Some(ContextServerConfiguration::Custom { command }),
 124            ContextServerSettings::Extension {
 125                enabled: _,
 126                settings,
 127            } => {
 128                let descriptor = cx
 129                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 130                    .ok()
 131                    .flatten()?;
 132
 133                match descriptor.command(worktree_store, cx).await {
 134                    Ok(command) => {
 135                        Some(ContextServerConfiguration::Extension { command, settings })
 136                    }
 137                    Err(e) => {
 138                        log::error!(
 139                            "Failed to create context server configuration from settings: {e:#}"
 140                        );
 141                        None
 142                    }
 143                }
 144            }
 145        }
 146    }
 147}
 148
 149pub type ContextServerFactory =
 150    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 151
 152pub struct ContextServerStore {
 153    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
 154    servers: HashMap<ContextServerId, ContextServerState>,
 155    worktree_store: Entity<WorktreeStore>,
 156    project: WeakEntity<Project>,
 157    registry: Entity<ContextServerDescriptorRegistry>,
 158    update_servers_task: Option<Task<Result<()>>>,
 159    context_server_factory: Option<ContextServerFactory>,
 160    needs_server_update: bool,
 161    _subscriptions: Vec<Subscription>,
 162}
 163
 164pub enum Event {
 165    ServerStatusChanged {
 166        server_id: ContextServerId,
 167        status: ContextServerStatus,
 168    },
 169}
 170
 171impl EventEmitter<Event> for ContextServerStore {}
 172
 173impl ContextServerStore {
 174    pub fn new(
 175        worktree_store: Entity<WorktreeStore>,
 176        weak_project: WeakEntity<Project>,
 177        cx: &mut Context<Self>,
 178    ) -> Self {
 179        Self::new_internal(
 180            true,
 181            None,
 182            ContextServerDescriptorRegistry::default_global(cx),
 183            worktree_store,
 184            weak_project,
 185            cx,
 186        )
 187    }
 188
 189    /// Returns all configured context server ids, regardless of enabled state.
 190    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 191        self.context_server_settings
 192            .keys()
 193            .cloned()
 194            .map(ContextServerId)
 195            .collect()
 196    }
 197
 198    #[cfg(any(test, feature = "test-support"))]
 199    pub fn test(
 200        registry: Entity<ContextServerDescriptorRegistry>,
 201        worktree_store: Entity<WorktreeStore>,
 202        weak_project: WeakEntity<Project>,
 203        cx: &mut Context<Self>,
 204    ) -> Self {
 205        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
 206    }
 207
 208    #[cfg(any(test, feature = "test-support"))]
 209    pub fn test_maintain_server_loop(
 210        context_server_factory: ContextServerFactory,
 211        registry: Entity<ContextServerDescriptorRegistry>,
 212        worktree_store: Entity<WorktreeStore>,
 213        weak_project: WeakEntity<Project>,
 214        cx: &mut Context<Self>,
 215    ) -> Self {
 216        Self::new_internal(
 217            true,
 218            Some(context_server_factory),
 219            registry,
 220            worktree_store,
 221            weak_project,
 222            cx,
 223        )
 224    }
 225
 226    fn new_internal(
 227        maintain_server_loop: bool,
 228        context_server_factory: Option<ContextServerFactory>,
 229        registry: Entity<ContextServerDescriptorRegistry>,
 230        worktree_store: Entity<WorktreeStore>,
 231        weak_project: WeakEntity<Project>,
 232        cx: &mut Context<Self>,
 233    ) -> Self {
 234        let subscriptions = if maintain_server_loop {
 235            vec![
 236                cx.observe(&registry, |this, _registry, cx| {
 237                    this.available_context_servers_changed(cx);
 238                }),
 239                cx.observe_global::<SettingsStore>(|this, cx| {
 240                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 241                    if &this.context_server_settings == settings {
 242                        return;
 243                    }
 244                    this.context_server_settings = settings.clone();
 245                    this.available_context_servers_changed(cx);
 246                }),
 247            ]
 248        } else {
 249            Vec::new()
 250        };
 251
 252        let mut this = Self {
 253            _subscriptions: subscriptions,
 254            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 255                .clone(),
 256            worktree_store,
 257            project: weak_project,
 258            registry,
 259            needs_server_update: false,
 260            servers: HashMap::default(),
 261            update_servers_task: None,
 262            context_server_factory,
 263        };
 264        if maintain_server_loop {
 265            this.available_context_servers_changed(cx);
 266        }
 267        this
 268    }
 269
 270    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 271        self.servers.get(id).map(|state| state.server())
 272    }
 273
 274    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 275        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 276            Some(server.clone())
 277        } else {
 278            None
 279        }
 280    }
 281
 282    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 283        self.servers.get(id).map(ContextServerStatus::from_state)
 284    }
 285
 286    pub fn configuration_for_server(
 287        &self,
 288        id: &ContextServerId,
 289    ) -> Option<Arc<ContextServerConfiguration>> {
 290        self.servers.get(id).map(|state| state.configuration())
 291    }
 292
 293    pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
 294        self.servers
 295            .keys()
 296            .cloned()
 297            .chain(
 298                self.registry
 299                    .read(cx)
 300                    .context_server_descriptors()
 301                    .into_iter()
 302                    .map(|(id, _)| ContextServerId(id)),
 303            )
 304            .collect()
 305    }
 306
 307    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 308        self.servers
 309            .values()
 310            .filter_map(|state| {
 311                if let ContextServerState::Running { server, .. } = state {
 312                    Some(server.clone())
 313                } else {
 314                    None
 315                }
 316            })
 317            .collect()
 318    }
 319
 320    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 321        cx.spawn(async move |this, cx| {
 322            let this = this.upgrade().context("Context server store dropped")?;
 323            let settings = this
 324                .update(cx, |this, _| {
 325                    this.context_server_settings.get(&server.id().0).cloned()
 326                })
 327                .ok()
 328                .flatten()
 329                .context("Failed to get context server settings")?;
 330
 331            if !settings.enabled() {
 332                return Ok(());
 333            }
 334
 335            let (registry, worktree_store) = this.update(cx, |this, _| {
 336                (this.registry.clone(), this.worktree_store.clone())
 337            })?;
 338            let configuration = ContextServerConfiguration::from_settings(
 339                settings,
 340                server.id(),
 341                registry,
 342                worktree_store,
 343                cx,
 344            )
 345            .await
 346            .context("Failed to create context server configuration")?;
 347
 348            this.update(cx, |this, cx| {
 349                this.run_server(server, Arc::new(configuration), cx)
 350            })
 351        })
 352        .detach_and_log_err(cx);
 353    }
 354
 355    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 356        if matches!(
 357            self.servers.get(id),
 358            Some(ContextServerState::Stopped { .. })
 359        ) {
 360            return Ok(());
 361        }
 362
 363        let state = self
 364            .servers
 365            .remove(id)
 366            .context("Context server not found")?;
 367
 368        let server = state.server();
 369        let configuration = state.configuration();
 370        let mut result = Ok(());
 371        if let ContextServerState::Running { server, .. } = &state {
 372            result = server.stop();
 373        }
 374        drop(state);
 375
 376        self.update_server_state(
 377            id.clone(),
 378            ContextServerState::Stopped {
 379                configuration,
 380                server,
 381            },
 382            cx,
 383        );
 384
 385        result
 386    }
 387
 388    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 389        if let Some(state) = self.servers.get(id) {
 390            let configuration = state.configuration();
 391
 392            self.stop_server(&state.server().id(), cx)?;
 393            let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
 394            self.run_server(new_server, configuration, cx);
 395        }
 396        Ok(())
 397    }
 398
 399    fn run_server(
 400        &mut self,
 401        server: Arc<ContextServer>,
 402        configuration: Arc<ContextServerConfiguration>,
 403        cx: &mut Context<Self>,
 404    ) {
 405        let id = server.id();
 406        if matches!(
 407            self.servers.get(&id),
 408            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 409        ) {
 410            self.stop_server(&id, cx).log_err();
 411        }
 412
 413        let task = cx.spawn({
 414            let id = server.id();
 415            let server = server.clone();
 416            let configuration = configuration.clone();
 417            async move |this, cx| {
 418                match server.clone().start(cx).await {
 419                    Ok(_) => {
 420                        debug_assert!(server.client().is_some());
 421
 422                        this.update(cx, |this, cx| {
 423                            this.update_server_state(
 424                                id.clone(),
 425                                ContextServerState::Running {
 426                                    server,
 427                                    configuration,
 428                                },
 429                                cx,
 430                            )
 431                        })
 432                        .log_err()
 433                    }
 434                    Err(err) => {
 435                        log::error!("{} context server failed to start: {}", id, err);
 436                        this.update(cx, |this, cx| {
 437                            this.update_server_state(
 438                                id.clone(),
 439                                ContextServerState::Error {
 440                                    configuration,
 441                                    server,
 442                                    error: err.to_string().into(),
 443                                },
 444                                cx,
 445                            )
 446                        })
 447                        .log_err()
 448                    }
 449                };
 450            }
 451        });
 452
 453        self.update_server_state(
 454            id.clone(),
 455            ContextServerState::Starting {
 456                configuration,
 457                _task: task,
 458                server,
 459            },
 460            cx,
 461        );
 462    }
 463
 464    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 465        let state = self
 466            .servers
 467            .remove(id)
 468            .context("Context server not found")?;
 469        drop(state);
 470        cx.emit(Event::ServerStatusChanged {
 471            server_id: id.clone(),
 472            status: ContextServerStatus::Stopped,
 473        });
 474        Ok(())
 475    }
 476
 477    fn create_context_server(
 478        &self,
 479        id: ContextServerId,
 480        configuration: Arc<ContextServerConfiguration>,
 481        cx: &mut Context<Self>,
 482    ) -> Arc<ContextServer> {
 483        let project = self.project.upgrade();
 484        let mut root_path = None;
 485        if let Some(project) = project {
 486            let project = project.read(cx);
 487            if project.is_local() {
 488                if let Some(path) = project.active_project_directory(cx) {
 489                    root_path = Some(path);
 490                } else {
 491                    for worktree in self.worktree_store.read(cx).visible_worktrees(cx) {
 492                        if let Some(path) = worktree.read(cx).root_dir() {
 493                            root_path = Some(path);
 494                            break;
 495                        }
 496                    }
 497                }
 498            }
 499        };
 500
 501        if let Some(factory) = self.context_server_factory.as_ref() {
 502            factory(id, configuration)
 503        } else {
 504            Arc::new(ContextServer::stdio(
 505                id,
 506                configuration.command().clone(),
 507                root_path,
 508            ))
 509        }
 510    }
 511
 512    fn resolve_context_server_settings<'a>(
 513        worktree_store: &'a Entity<WorktreeStore>,
 514        cx: &'a App,
 515    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 516        let location = worktree_store
 517            .read(cx)
 518            .visible_worktrees(cx)
 519            .next()
 520            .map(|worktree| settings::SettingsLocation {
 521                worktree_id: worktree.read(cx).id(),
 522                path: RelPath::empty(),
 523            });
 524        &ProjectSettings::get(location, cx).context_servers
 525    }
 526
 527    fn update_server_state(
 528        &mut self,
 529        id: ContextServerId,
 530        state: ContextServerState,
 531        cx: &mut Context<Self>,
 532    ) {
 533        let status = ContextServerStatus::from_state(&state);
 534        self.servers.insert(id.clone(), state);
 535        cx.emit(Event::ServerStatusChanged {
 536            server_id: id,
 537            status,
 538        });
 539    }
 540
 541    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 542        if self.update_servers_task.is_some() {
 543            self.needs_server_update = true;
 544        } else {
 545            self.needs_server_update = false;
 546            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 547                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 548                    log::error!("Error maintaining context servers: {}", err);
 549                }
 550
 551                this.update(cx, |this, cx| {
 552                    this.update_servers_task.take();
 553                    if this.needs_server_update {
 554                        this.available_context_servers_changed(cx);
 555                    }
 556                })?;
 557
 558                Ok(())
 559            }));
 560        }
 561    }
 562
 563    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 564        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 565            (
 566                this.context_server_settings.clone(),
 567                this.registry.clone(),
 568                this.worktree_store.clone(),
 569            )
 570        })?;
 571
 572        for (id, _) in
 573            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 574        {
 575            configured_servers
 576                .entry(id)
 577                .or_insert(ContextServerSettings::default_extension());
 578        }
 579
 580        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 581            configured_servers
 582                .into_iter()
 583                .partition(|(_, settings)| settings.enabled());
 584
 585        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 586            let id = ContextServerId(id);
 587            ContextServerConfiguration::from_settings(
 588                settings,
 589                id.clone(),
 590                registry.clone(),
 591                worktree_store.clone(),
 592                cx,
 593            )
 594            .map(|config| (id, config))
 595        }))
 596        .await
 597        .into_iter()
 598        .filter_map(|(id, config)| config.map(|config| (id, config)))
 599        .collect::<HashMap<_, _>>();
 600
 601        let mut servers_to_start = Vec::new();
 602        let mut servers_to_remove = HashSet::default();
 603        let mut servers_to_stop = HashSet::default();
 604
 605        this.update(cx, |this, cx| {
 606            for server_id in this.servers.keys() {
 607                // All servers that are not in desired_servers should be removed from the store.
 608                // This can happen if the user removed a server from the context server settings.
 609                if !configured_servers.contains_key(server_id) {
 610                    if disabled_servers.contains_key(&server_id.0) {
 611                        servers_to_stop.insert(server_id.clone());
 612                    } else {
 613                        servers_to_remove.insert(server_id.clone());
 614                    }
 615                }
 616            }
 617
 618            for (id, config) in configured_servers {
 619                let state = this.servers.get(&id);
 620                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 621                let existing_config = state.as_ref().map(|state| state.configuration());
 622                if existing_config.as_deref() != Some(&config) || is_stopped {
 623                    let config = Arc::new(config);
 624                    let server = this.create_context_server(id.clone(), config.clone(), cx);
 625                    servers_to_start.push((server, config));
 626                    if this.servers.contains_key(&id) {
 627                        servers_to_stop.insert(id);
 628                    }
 629                }
 630            }
 631        })?;
 632
 633        this.update(cx, |this, cx| {
 634            for id in servers_to_stop {
 635                this.stop_server(&id, cx)?;
 636            }
 637            for id in servers_to_remove {
 638                this.remove_server(&id, cx)?;
 639            }
 640            for (server, config) in servers_to_start {
 641                this.run_server(server, config, cx);
 642            }
 643            anyhow::Ok(())
 644        })?
 645    }
 646}
 647
 648#[cfg(test)]
 649mod tests {
 650    use super::*;
 651    use crate::{
 652        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 653        project_settings::ProjectSettings,
 654    };
 655    use context_server::test::create_fake_transport;
 656    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 657    use serde_json::json;
 658    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 659    use util::path;
 660
 661    #[gpui::test]
 662    async fn test_context_server_status(cx: &mut TestAppContext) {
 663        const SERVER_1_ID: &str = "mcp-1";
 664        const SERVER_2_ID: &str = "mcp-2";
 665
 666        let (_fs, project) = setup_context_server_test(
 667            cx,
 668            json!({"code.rs": ""}),
 669            vec![
 670                (SERVER_1_ID.into(), dummy_server_settings()),
 671                (SERVER_2_ID.into(), dummy_server_settings()),
 672            ],
 673        )
 674        .await;
 675
 676        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 677        let store = cx.new(|cx| {
 678            ContextServerStore::test(
 679                registry.clone(),
 680                project.read(cx).worktree_store(),
 681                project.downgrade(),
 682                cx,
 683            )
 684        });
 685
 686        let server_1_id = ContextServerId(SERVER_1_ID.into());
 687        let server_2_id = ContextServerId(SERVER_2_ID.into());
 688
 689        let server_1 = Arc::new(ContextServer::new(
 690            server_1_id.clone(),
 691            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 692        ));
 693        let server_2 = Arc::new(ContextServer::new(
 694            server_2_id.clone(),
 695            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 696        ));
 697
 698        store.update(cx, |store, cx| store.start_server(server_1, cx));
 699
 700        cx.run_until_parked();
 701
 702        cx.update(|cx| {
 703            assert_eq!(
 704                store.read(cx).status_for_server(&server_1_id),
 705                Some(ContextServerStatus::Running)
 706            );
 707            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 708        });
 709
 710        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 711
 712        cx.run_until_parked();
 713
 714        cx.update(|cx| {
 715            assert_eq!(
 716                store.read(cx).status_for_server(&server_1_id),
 717                Some(ContextServerStatus::Running)
 718            );
 719            assert_eq!(
 720                store.read(cx).status_for_server(&server_2_id),
 721                Some(ContextServerStatus::Running)
 722            );
 723        });
 724
 725        store
 726            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 727            .unwrap();
 728
 729        cx.update(|cx| {
 730            assert_eq!(
 731                store.read(cx).status_for_server(&server_1_id),
 732                Some(ContextServerStatus::Running)
 733            );
 734            assert_eq!(
 735                store.read(cx).status_for_server(&server_2_id),
 736                Some(ContextServerStatus::Stopped)
 737            );
 738        });
 739    }
 740
 741    #[gpui::test]
 742    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 743        const SERVER_1_ID: &str = "mcp-1";
 744        const SERVER_2_ID: &str = "mcp-2";
 745
 746        let (_fs, project) = setup_context_server_test(
 747            cx,
 748            json!({"code.rs": ""}),
 749            vec![
 750                (SERVER_1_ID.into(), dummy_server_settings()),
 751                (SERVER_2_ID.into(), dummy_server_settings()),
 752            ],
 753        )
 754        .await;
 755
 756        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 757        let store = cx.new(|cx| {
 758            ContextServerStore::test(
 759                registry.clone(),
 760                project.read(cx).worktree_store(),
 761                project.downgrade(),
 762                cx,
 763            )
 764        });
 765
 766        let server_1_id = ContextServerId(SERVER_1_ID.into());
 767        let server_2_id = ContextServerId(SERVER_2_ID.into());
 768
 769        let server_1 = Arc::new(ContextServer::new(
 770            server_1_id.clone(),
 771            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 772        ));
 773        let server_2 = Arc::new(ContextServer::new(
 774            server_2_id.clone(),
 775            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 776        ));
 777
 778        let _server_events = assert_server_events(
 779            &store,
 780            vec![
 781                (server_1_id.clone(), ContextServerStatus::Starting),
 782                (server_1_id, ContextServerStatus::Running),
 783                (server_2_id.clone(), ContextServerStatus::Starting),
 784                (server_2_id.clone(), ContextServerStatus::Running),
 785                (server_2_id.clone(), ContextServerStatus::Stopped),
 786            ],
 787            cx,
 788        );
 789
 790        store.update(cx, |store, cx| store.start_server(server_1, cx));
 791
 792        cx.run_until_parked();
 793
 794        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 795
 796        cx.run_until_parked();
 797
 798        store
 799            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 800            .unwrap();
 801    }
 802
 803    #[gpui::test(iterations = 25)]
 804    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 805        const SERVER_1_ID: &str = "mcp-1";
 806
 807        let (_fs, project) = setup_context_server_test(
 808            cx,
 809            json!({"code.rs": ""}),
 810            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 811        )
 812        .await;
 813
 814        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 815        let store = cx.new(|cx| {
 816            ContextServerStore::test(
 817                registry.clone(),
 818                project.read(cx).worktree_store(),
 819                project.downgrade(),
 820                cx,
 821            )
 822        });
 823
 824        let server_id = ContextServerId(SERVER_1_ID.into());
 825
 826        let server_with_same_id_1 = Arc::new(ContextServer::new(
 827            server_id.clone(),
 828            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 829        ));
 830        let server_with_same_id_2 = Arc::new(ContextServer::new(
 831            server_id.clone(),
 832            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 833        ));
 834
 835        // If we start another server with the same id, we should report that we stopped the previous one
 836        let _server_events = assert_server_events(
 837            &store,
 838            vec![
 839                (server_id.clone(), ContextServerStatus::Starting),
 840                (server_id.clone(), ContextServerStatus::Stopped),
 841                (server_id.clone(), ContextServerStatus::Starting),
 842                (server_id.clone(), ContextServerStatus::Running),
 843            ],
 844            cx,
 845        );
 846
 847        store.update(cx, |store, cx| {
 848            store.start_server(server_with_same_id_1.clone(), cx)
 849        });
 850        store.update(cx, |store, cx| {
 851            store.start_server(server_with_same_id_2.clone(), cx)
 852        });
 853
 854        cx.run_until_parked();
 855
 856        cx.update(|cx| {
 857            assert_eq!(
 858                store.read(cx).status_for_server(&server_id),
 859                Some(ContextServerStatus::Running)
 860            );
 861        });
 862    }
 863
 864    #[gpui::test]
 865    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 866        const SERVER_1_ID: &str = "mcp-1";
 867        const SERVER_2_ID: &str = "mcp-2";
 868
 869        let server_1_id = ContextServerId(SERVER_1_ID.into());
 870        let server_2_id = ContextServerId(SERVER_2_ID.into());
 871
 872        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 873
 874        let (_fs, project) = setup_context_server_test(
 875            cx,
 876            json!({"code.rs": ""}),
 877            vec![(
 878                SERVER_1_ID.into(),
 879                ContextServerSettings::Extension {
 880                    enabled: true,
 881                    settings: json!({
 882                        "somevalue": true
 883                    }),
 884                },
 885            )],
 886        )
 887        .await;
 888
 889        let executor = cx.executor();
 890        let registry = cx.new(|cx| {
 891            let mut registry = ContextServerDescriptorRegistry::new();
 892            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 893            registry
 894        });
 895        let store = cx.new(|cx| {
 896            ContextServerStore::test_maintain_server_loop(
 897                Box::new(move |id, _| {
 898                    Arc::new(ContextServer::new(
 899                        id.clone(),
 900                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 901                    ))
 902                }),
 903                registry.clone(),
 904                project.read(cx).worktree_store(),
 905                project.downgrade(),
 906                cx,
 907            )
 908        });
 909
 910        // Ensure that mcp-1 starts up
 911        {
 912            let _server_events = assert_server_events(
 913                &store,
 914                vec![
 915                    (server_1_id.clone(), ContextServerStatus::Starting),
 916                    (server_1_id.clone(), ContextServerStatus::Running),
 917                ],
 918                cx,
 919            );
 920            cx.run_until_parked();
 921        }
 922
 923        // Ensure that mcp-1 is restarted when the configuration was changed
 924        {
 925            let _server_events = assert_server_events(
 926                &store,
 927                vec![
 928                    (server_1_id.clone(), ContextServerStatus::Stopped),
 929                    (server_1_id.clone(), ContextServerStatus::Starting),
 930                    (server_1_id.clone(), ContextServerStatus::Running),
 931                ],
 932                cx,
 933            );
 934            set_context_server_configuration(
 935                vec![(
 936                    server_1_id.0.clone(),
 937                    settings::ContextServerSettingsContent::Extension {
 938                        enabled: true,
 939                        settings: json!({
 940                            "somevalue": false
 941                        }),
 942                    },
 943                )],
 944                cx,
 945            );
 946
 947            cx.run_until_parked();
 948        }
 949
 950        // Ensure that mcp-1 is not restarted when the configuration was not changed
 951        {
 952            let _server_events = assert_server_events(&store, vec![], cx);
 953            set_context_server_configuration(
 954                vec![(
 955                    server_1_id.0.clone(),
 956                    settings::ContextServerSettingsContent::Extension {
 957                        enabled: true,
 958                        settings: json!({
 959                            "somevalue": false
 960                        }),
 961                    },
 962                )],
 963                cx,
 964            );
 965
 966            cx.run_until_parked();
 967        }
 968
 969        // Ensure that mcp-2 is started once it is added to the settings
 970        {
 971            let _server_events = assert_server_events(
 972                &store,
 973                vec![
 974                    (server_2_id.clone(), ContextServerStatus::Starting),
 975                    (server_2_id.clone(), ContextServerStatus::Running),
 976                ],
 977                cx,
 978            );
 979            set_context_server_configuration(
 980                vec![
 981                    (
 982                        server_1_id.0.clone(),
 983                        settings::ContextServerSettingsContent::Extension {
 984                            enabled: true,
 985                            settings: json!({
 986                                "somevalue": false
 987                            }),
 988                        },
 989                    ),
 990                    (
 991                        server_2_id.0.clone(),
 992                        settings::ContextServerSettingsContent::Custom {
 993                            enabled: true,
 994                            command: ContextServerCommand {
 995                                path: "somebinary".into(),
 996                                args: vec!["arg".to_string()],
 997                                env: None,
 998                                timeout: None,
 999                            },
1000                        },
1001                    ),
1002                ],
1003                cx,
1004            );
1005
1006            cx.run_until_parked();
1007        }
1008
1009        // Ensure that mcp-2 is restarted once the args have changed
1010        {
1011            let _server_events = assert_server_events(
1012                &store,
1013                vec![
1014                    (server_2_id.clone(), ContextServerStatus::Stopped),
1015                    (server_2_id.clone(), ContextServerStatus::Starting),
1016                    (server_2_id.clone(), ContextServerStatus::Running),
1017                ],
1018                cx,
1019            );
1020            set_context_server_configuration(
1021                vec![
1022                    (
1023                        server_1_id.0.clone(),
1024                        settings::ContextServerSettingsContent::Extension {
1025                            enabled: true,
1026                            settings: json!({
1027                                "somevalue": false
1028                            }),
1029                        },
1030                    ),
1031                    (
1032                        server_2_id.0.clone(),
1033                        settings::ContextServerSettingsContent::Custom {
1034                            enabled: true,
1035                            command: ContextServerCommand {
1036                                path: "somebinary".into(),
1037                                args: vec!["anotherArg".to_string()],
1038                                env: None,
1039                                timeout: None,
1040                            },
1041                        },
1042                    ),
1043                ],
1044                cx,
1045            );
1046
1047            cx.run_until_parked();
1048        }
1049
1050        // Ensure that mcp-2 is removed once it is removed from the settings
1051        {
1052            let _server_events = assert_server_events(
1053                &store,
1054                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1055                cx,
1056            );
1057            set_context_server_configuration(
1058                vec![(
1059                    server_1_id.0.clone(),
1060                    settings::ContextServerSettingsContent::Extension {
1061                        enabled: true,
1062                        settings: json!({
1063                            "somevalue": false
1064                        }),
1065                    },
1066                )],
1067                cx,
1068            );
1069
1070            cx.run_until_parked();
1071
1072            cx.update(|cx| {
1073                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1074            });
1075        }
1076
1077        // Ensure that nothing happens if the settings do not change
1078        {
1079            let _server_events = assert_server_events(&store, vec![], cx);
1080            set_context_server_configuration(
1081                vec![(
1082                    server_1_id.0.clone(),
1083                    settings::ContextServerSettingsContent::Extension {
1084                        enabled: true,
1085                        settings: json!({
1086                            "somevalue": false
1087                        }),
1088                    },
1089                )],
1090                cx,
1091            );
1092
1093            cx.run_until_parked();
1094
1095            cx.update(|cx| {
1096                assert_eq!(
1097                    store.read(cx).status_for_server(&server_1_id),
1098                    Some(ContextServerStatus::Running)
1099                );
1100                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1101            });
1102        }
1103    }
1104
1105    #[gpui::test]
1106    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1107        const SERVER_1_ID: &str = "mcp-1";
1108
1109        let server_1_id = ContextServerId(SERVER_1_ID.into());
1110
1111        let (_fs, project) = setup_context_server_test(
1112            cx,
1113            json!({"code.rs": ""}),
1114            vec![(
1115                SERVER_1_ID.into(),
1116                ContextServerSettings::Custom {
1117                    enabled: true,
1118                    command: ContextServerCommand {
1119                        path: "somebinary".into(),
1120                        args: vec!["arg".to_string()],
1121                        env: None,
1122                        timeout: None,
1123                    },
1124                },
1125            )],
1126        )
1127        .await;
1128
1129        let executor = cx.executor();
1130        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1131        let store = cx.new(|cx| {
1132            ContextServerStore::test_maintain_server_loop(
1133                Box::new(move |id, _| {
1134                    Arc::new(ContextServer::new(
1135                        id.clone(),
1136                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1137                    ))
1138                }),
1139                registry.clone(),
1140                project.read(cx).worktree_store(),
1141                project.downgrade(),
1142                cx,
1143            )
1144        });
1145
1146        // Ensure that mcp-1 starts up
1147        {
1148            let _server_events = assert_server_events(
1149                &store,
1150                vec![
1151                    (server_1_id.clone(), ContextServerStatus::Starting),
1152                    (server_1_id.clone(), ContextServerStatus::Running),
1153                ],
1154                cx,
1155            );
1156            cx.run_until_parked();
1157        }
1158
1159        // Ensure that mcp-1 is stopped once it is disabled.
1160        {
1161            let _server_events = assert_server_events(
1162                &store,
1163                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1164                cx,
1165            );
1166            set_context_server_configuration(
1167                vec![(
1168                    server_1_id.0.clone(),
1169                    settings::ContextServerSettingsContent::Custom {
1170                        enabled: false,
1171                        command: ContextServerCommand {
1172                            path: "somebinary".into(),
1173                            args: vec!["arg".to_string()],
1174                            env: None,
1175                            timeout: None,
1176                        },
1177                    },
1178                )],
1179                cx,
1180            );
1181
1182            cx.run_until_parked();
1183        }
1184
1185        // Ensure that mcp-1 is started once it is enabled again.
1186        {
1187            let _server_events = assert_server_events(
1188                &store,
1189                vec![
1190                    (server_1_id.clone(), ContextServerStatus::Starting),
1191                    (server_1_id.clone(), ContextServerStatus::Running),
1192                ],
1193                cx,
1194            );
1195            set_context_server_configuration(
1196                vec![(
1197                    server_1_id.0.clone(),
1198                    settings::ContextServerSettingsContent::Custom {
1199                        enabled: true,
1200                        command: ContextServerCommand {
1201                            path: "somebinary".into(),
1202                            args: vec!["arg".to_string()],
1203                            timeout: None,
1204                            env: None,
1205                        },
1206                    },
1207                )],
1208                cx,
1209            );
1210
1211            cx.run_until_parked();
1212        }
1213    }
1214
1215    fn set_context_server_configuration(
1216        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1217        cx: &mut TestAppContext,
1218    ) {
1219        cx.update(|cx| {
1220            SettingsStore::update_global(cx, |store, cx| {
1221                store.update_user_settings(cx, |content| {
1222                    content.project.context_servers.clear();
1223                    for (id, config) in context_servers {
1224                        content.project.context_servers.insert(id, config);
1225                    }
1226                });
1227            })
1228        });
1229    }
1230
1231    struct ServerEvents {
1232        received_event_count: Rc<RefCell<usize>>,
1233        expected_event_count: usize,
1234        _subscription: Subscription,
1235    }
1236
1237    impl Drop for ServerEvents {
1238        fn drop(&mut self) {
1239            let actual_event_count = *self.received_event_count.borrow();
1240            assert_eq!(
1241                actual_event_count, self.expected_event_count,
1242                "
1243                Expected to receive {} context server store events, but received {} events",
1244                self.expected_event_count, actual_event_count
1245            );
1246        }
1247    }
1248
1249    fn dummy_server_settings() -> ContextServerSettings {
1250        ContextServerSettings::Custom {
1251            enabled: true,
1252            command: ContextServerCommand {
1253                path: "somebinary".into(),
1254                args: vec!["arg".to_string()],
1255                env: None,
1256                timeout: None,
1257            },
1258        }
1259    }
1260
1261    fn assert_server_events(
1262        store: &Entity<ContextServerStore>,
1263        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1264        cx: &mut TestAppContext,
1265    ) -> ServerEvents {
1266        cx.update(|cx| {
1267            let mut ix = 0;
1268            let received_event_count = Rc::new(RefCell::new(0));
1269            let expected_event_count = expected_events.len();
1270            let subscription = cx.subscribe(store, {
1271                let received_event_count = received_event_count.clone();
1272                move |_, event, _| match event {
1273                    Event::ServerStatusChanged {
1274                        server_id: actual_server_id,
1275                        status: actual_status,
1276                    } => {
1277                        let (expected_server_id, expected_status) = &expected_events[ix];
1278
1279                        assert_eq!(
1280                            actual_server_id, expected_server_id,
1281                            "Expected different server id at index {}",
1282                            ix
1283                        );
1284                        assert_eq!(
1285                            actual_status, expected_status,
1286                            "Expected different status at index {}",
1287                            ix
1288                        );
1289                        ix += 1;
1290                        *received_event_count.borrow_mut() += 1;
1291                    }
1292                }
1293            });
1294            ServerEvents {
1295                expected_event_count,
1296                received_event_count,
1297                _subscription: subscription,
1298            }
1299        })
1300    }
1301
1302    async fn setup_context_server_test(
1303        cx: &mut TestAppContext,
1304        files: serde_json::Value,
1305        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1306    ) -> (Arc<FakeFs>, Entity<Project>) {
1307        cx.update(|cx| {
1308            let settings_store = SettingsStore::test(cx);
1309            cx.set_global(settings_store);
1310            let mut settings = ProjectSettings::get_global(cx).clone();
1311            for (id, config) in context_server_configurations {
1312                settings.context_servers.insert(id, config);
1313            }
1314            ProjectSettings::override_global(settings, cx);
1315        });
1316
1317        let fs = FakeFs::new(cx.executor());
1318        fs.insert_tree(path!("/test"), files).await;
1319        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1320
1321        (fs, project)
1322    }
1323
1324    struct FakeContextServerDescriptor {
1325        path: PathBuf,
1326    }
1327
1328    impl FakeContextServerDescriptor {
1329        fn new(path: impl Into<PathBuf>) -> Self {
1330            Self { path: path.into() }
1331        }
1332    }
1333
1334    impl ContextServerDescriptor for FakeContextServerDescriptor {
1335        fn command(
1336            &self,
1337            _worktree_store: Entity<WorktreeStore>,
1338            _cx: &AsyncApp,
1339        ) -> Task<Result<ContextServerCommand>> {
1340            Task::ready(Ok(ContextServerCommand {
1341                path: self.path.clone(),
1342                args: vec!["arg1".to_string(), "arg2".to_string()],
1343                env: None,
1344                timeout: None,
1345            }))
1346        }
1347
1348        fn configuration(
1349            &self,
1350            _worktree_store: Entity<WorktreeStore>,
1351            _cx: &AsyncApp,
1352        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1353            Task::ready(Ok(None))
1354        }
1355    }
1356}