context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    project_settings::{ContextServerSettings, ProjectSettings},
  17    worktree_store::WorktreeStore,
  18};
  19
/// Initializes this module by delegating to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
  23
// gpui `actions!` invocation: `context_server` is the first argument and
// `Restart` is the only action listed.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  31
  32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  33pub enum ContextServerStatus {
  34    Starting,
  35    Running,
  36    Stopped,
  37    Error(Arc<str>),
  38}
  39
  40impl ContextServerStatus {
  41    fn from_state(state: &ContextServerState) -> Self {
  42        match state {
  43            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  44            ContextServerState::Running { .. } => ContextServerStatus::Running,
  45            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  46            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  47        }
  48    }
  49}
  50
/// Internal per-server state tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was created
/// with, so both remain available regardless of lifecycle stage.
enum ContextServerState {
    /// `run_server` has spawned the start task, which has not finished yet.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // The task awaiting `server.start(..)`; stored so it lives exactly as
        // long as this state does.
        // NOTE(review): presumably dropping it cancels the start — confirm
        // against gpui `Task` semantics.
        _task: Task<()>,
    },
    /// The start task completed successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Set by `stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The start task failed; `error` is the failure rendered as a string.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
  71
  72impl ContextServerState {
  73    pub fn server(&self) -> Arc<ContextServer> {
  74        match self {
  75            ContextServerState::Starting { server, .. } => server.clone(),
  76            ContextServerState::Running { server, .. } => server.clone(),
  77            ContextServerState::Stopped { server, .. } => server.clone(),
  78            ContextServerState::Error { server, .. } => server.clone(),
  79        }
  80    }
  81
  82    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  83        match self {
  84            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  85            ContextServerState::Running { configuration, .. } => configuration.clone(),
  86            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  87            ContextServerState::Error { configuration, .. } => configuration.clone(),
  88        }
  89    }
  90}
  91
  92#[derive(Debug, PartialEq, Eq)]
  93pub enum ContextServerConfiguration {
  94    Custom {
  95        command: ContextServerCommand,
  96    },
  97    Extension {
  98        command: ContextServerCommand,
  99        settings: serde_json::Value,
 100    },
 101}
 102
 103impl ContextServerConfiguration {
 104    pub fn command(&self) -> &ContextServerCommand {
 105        match self {
 106            ContextServerConfiguration::Custom { command } => command,
 107            ContextServerConfiguration::Extension { command, .. } => command,
 108        }
 109    }
 110
 111    pub async fn from_settings(
 112        settings: ContextServerSettings,
 113        id: ContextServerId,
 114        registry: Entity<ContextServerDescriptorRegistry>,
 115        worktree_store: Entity<WorktreeStore>,
 116        cx: &AsyncApp,
 117    ) -> Option<Self> {
 118        match settings {
 119            ContextServerSettings::Custom {
 120                enabled: _,
 121                command,
 122            } => Some(ContextServerConfiguration::Custom { command }),
 123            ContextServerSettings::Extension {
 124                enabled: _,
 125                settings,
 126            } => {
 127                let descriptor = cx
 128                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 129                    .ok()
 130                    .flatten()?;
 131
 132                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 133
 134                Some(ContextServerConfiguration::Extension { command, settings })
 135            }
 136        }
 137    }
 138}
 139
/// Builds the [`ContextServer`] for a given id and configuration.
///
/// When a factory is installed on the store, `create_context_server` calls it
/// instead of building a stdio server from the configuration's command.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 142
/// Tracks context servers and their lifecycle state, and (when constructed
/// with the maintenance loop enabled) reconciles them against settings and
/// the descriptor registry.
pub struct ContextServerStore {
    // Copy of `ProjectSettings::context_servers` as last resolved, keyed by
    // server name.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current state of every known server.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to pick the settings location and handed to descriptors when
    // resolving an extension server's command.
    worktree_store: Entity<WorktreeStore>,
    // Source of descriptors for extension-provided servers.
    registry: Entity<ContextServerDescriptorRegistry>,
    // The in-flight `maintain_servers` pass, if one is running.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for constructing servers; `None` means
    // `ContextServer::stdio` is used.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a change arrives while a pass is running, so another pass is
    // scheduled once the current one finishes.
    needs_server_update: bool,
    // Registry and settings observers created in `new_internal`, held for the
    // lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
 153
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// Emitted whenever a server's state is replaced via
    /// `update_server_state`, and with `Stopped` when a server is removed.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

// Lets the store emit `Event` through `cx.emit`.
impl EventEmitter<Event> for ContextServerStore {}
 162
 163impl ContextServerStore {
 164    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 165        Self::new_internal(
 166            true,
 167            None,
 168            ContextServerDescriptorRegistry::default_global(cx),
 169            worktree_store,
 170            cx,
 171        )
 172    }
 173
 174    #[cfg(any(test, feature = "test-support"))]
 175    pub fn test(
 176        registry: Entity<ContextServerDescriptorRegistry>,
 177        worktree_store: Entity<WorktreeStore>,
 178        cx: &mut Context<Self>,
 179    ) -> Self {
 180        Self::new_internal(false, None, registry, worktree_store, cx)
 181    }
 182
 183    #[cfg(any(test, feature = "test-support"))]
 184    pub fn test_maintain_server_loop(
 185        context_server_factory: ContextServerFactory,
 186        registry: Entity<ContextServerDescriptorRegistry>,
 187        worktree_store: Entity<WorktreeStore>,
 188        cx: &mut Context<Self>,
 189    ) -> Self {
 190        Self::new_internal(
 191            true,
 192            Some(context_server_factory),
 193            registry,
 194            worktree_store,
 195            cx,
 196        )
 197    }
 198
 199    fn new_internal(
 200        maintain_server_loop: bool,
 201        context_server_factory: Option<ContextServerFactory>,
 202        registry: Entity<ContextServerDescriptorRegistry>,
 203        worktree_store: Entity<WorktreeStore>,
 204        cx: &mut Context<Self>,
 205    ) -> Self {
 206        let subscriptions = if maintain_server_loop {
 207            vec![
 208                cx.observe(&registry, |this, _registry, cx| {
 209                    this.available_context_servers_changed(cx);
 210                }),
 211                cx.observe_global::<SettingsStore>(|this, cx| {
 212                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 213                    if &this.context_server_settings == settings {
 214                        return;
 215                    }
 216                    this.context_server_settings = settings.clone();
 217                    this.available_context_servers_changed(cx);
 218                }),
 219            ]
 220        } else {
 221            Vec::new()
 222        };
 223
 224        let mut this = Self {
 225            _subscriptions: subscriptions,
 226            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 227                .clone(),
 228            worktree_store,
 229            registry,
 230            needs_server_update: false,
 231            servers: HashMap::default(),
 232            update_servers_task: None,
 233            context_server_factory,
 234        };
 235        if maintain_server_loop {
 236            this.available_context_servers_changed(cx);
 237        }
 238        this
 239    }
 240
 241    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 242        self.servers.get(id).map(|state| state.server())
 243    }
 244
 245    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 246        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 247            Some(server.clone())
 248        } else {
 249            None
 250        }
 251    }
 252
 253    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 254        self.servers.get(id).map(ContextServerStatus::from_state)
 255    }
 256
 257    pub fn configuration_for_server(
 258        &self,
 259        id: &ContextServerId,
 260    ) -> Option<Arc<ContextServerConfiguration>> {
 261        self.servers.get(id).map(|state| state.configuration())
 262    }
 263
 264    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 265        self.servers.keys().cloned().collect()
 266    }
 267
 268    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 269        self.servers
 270            .values()
 271            .filter_map(|state| {
 272                if let ContextServerState::Running { server, .. } = state {
 273                    Some(server.clone())
 274                } else {
 275                    None
 276                }
 277            })
 278            .collect()
 279    }
 280
 281    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 282        cx.spawn(async move |this, cx| {
 283            let this = this.upgrade().context("Context server store dropped")?;
 284            let settings = this
 285                .update(cx, |this, _| {
 286                    this.context_server_settings.get(&server.id().0).cloned()
 287                })
 288                .ok()
 289                .flatten()
 290                .context("Failed to get context server settings")?;
 291
 292            if !settings.enabled() {
 293                return Ok(());
 294            }
 295
 296            let (registry, worktree_store) = this.update(cx, |this, _| {
 297                (this.registry.clone(), this.worktree_store.clone())
 298            })?;
 299            let configuration = ContextServerConfiguration::from_settings(
 300                settings,
 301                server.id(),
 302                registry,
 303                worktree_store,
 304                cx,
 305            )
 306            .await
 307            .context("Failed to create context server configuration")?;
 308
 309            this.update(cx, |this, cx| {
 310                this.run_server(server, Arc::new(configuration), cx)
 311            })
 312        })
 313        .detach_and_log_err(cx);
 314    }
 315
 316    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 317        if matches!(
 318            self.servers.get(id),
 319            Some(ContextServerState::Stopped { .. })
 320        ) {
 321            return Ok(());
 322        }
 323
 324        let state = self
 325            .servers
 326            .remove(id)
 327            .context("Context server not found")?;
 328
 329        let server = state.server();
 330        let configuration = state.configuration();
 331        let mut result = Ok(());
 332        if let ContextServerState::Running { server, .. } = &state {
 333            result = server.stop();
 334        }
 335        drop(state);
 336
 337        self.update_server_state(
 338            id.clone(),
 339            ContextServerState::Stopped {
 340                configuration,
 341                server,
 342            },
 343            cx,
 344        );
 345
 346        result
 347    }
 348
 349    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 350        if let Some(state) = self.servers.get(&id) {
 351            let configuration = state.configuration();
 352
 353            self.stop_server(&state.server().id(), cx)?;
 354            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 355            self.run_server(new_server, configuration, cx);
 356        }
 357        Ok(())
 358    }
 359
 360    fn run_server(
 361        &mut self,
 362        server: Arc<ContextServer>,
 363        configuration: Arc<ContextServerConfiguration>,
 364        cx: &mut Context<Self>,
 365    ) {
 366        let id = server.id();
 367        if matches!(
 368            self.servers.get(&id),
 369            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 370        ) {
 371            self.stop_server(&id, cx).log_err();
 372        }
 373
 374        let task = cx.spawn({
 375            let id = server.id();
 376            let server = server.clone();
 377            let configuration = configuration.clone();
 378            async move |this, cx| {
 379                match server.clone().start(&cx).await {
 380                    Ok(_) => {
 381                        log::info!("Started {} context server", id);
 382                        debug_assert!(server.client().is_some());
 383
 384                        this.update(cx, |this, cx| {
 385                            this.update_server_state(
 386                                id.clone(),
 387                                ContextServerState::Running {
 388                                    server,
 389                                    configuration,
 390                                },
 391                                cx,
 392                            )
 393                        })
 394                        .log_err()
 395                    }
 396                    Err(err) => {
 397                        log::error!("{} context server failed to start: {}", id, err);
 398                        this.update(cx, |this, cx| {
 399                            this.update_server_state(
 400                                id.clone(),
 401                                ContextServerState::Error {
 402                                    configuration,
 403                                    server,
 404                                    error: err.to_string().into(),
 405                                },
 406                                cx,
 407                            )
 408                        })
 409                        .log_err()
 410                    }
 411                };
 412            }
 413        });
 414
 415        self.update_server_state(
 416            id.clone(),
 417            ContextServerState::Starting {
 418                configuration,
 419                _task: task,
 420                server,
 421            },
 422            cx,
 423        );
 424    }
 425
 426    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 427        let state = self
 428            .servers
 429            .remove(id)
 430            .context("Context server not found")?;
 431        drop(state);
 432        cx.emit(Event::ServerStatusChanged {
 433            server_id: id.clone(),
 434            status: ContextServerStatus::Stopped,
 435        });
 436        Ok(())
 437    }
 438
 439    fn create_context_server(
 440        &self,
 441        id: ContextServerId,
 442        configuration: Arc<ContextServerConfiguration>,
 443    ) -> Result<Arc<ContextServer>> {
 444        if let Some(factory) = self.context_server_factory.as_ref() {
 445            Ok(factory(id, configuration))
 446        } else {
 447            Ok(Arc::new(ContextServer::stdio(
 448                id,
 449                configuration.command().clone(),
 450            )))
 451        }
 452    }
 453
 454    fn resolve_context_server_settings<'a>(
 455        worktree_store: &'a Entity<WorktreeStore>,
 456        cx: &'a App,
 457    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 458        let location = worktree_store
 459            .read(cx)
 460            .visible_worktrees(cx)
 461            .next()
 462            .map(|worktree| settings::SettingsLocation {
 463                worktree_id: worktree.read(cx).id(),
 464                path: Path::new(""),
 465            });
 466        &ProjectSettings::get(location, cx).context_servers
 467    }
 468
 469    fn update_server_state(
 470        &mut self,
 471        id: ContextServerId,
 472        state: ContextServerState,
 473        cx: &mut Context<Self>,
 474    ) {
 475        let status = ContextServerStatus::from_state(&state);
 476        self.servers.insert(id.clone(), state);
 477        cx.emit(Event::ServerStatusChanged {
 478            server_id: id,
 479            status,
 480        });
 481    }
 482
 483    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 484        if self.update_servers_task.is_some() {
 485            self.needs_server_update = true;
 486        } else {
 487            self.needs_server_update = false;
 488            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 489                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 490                    log::error!("Error maintaining context servers: {}", err);
 491                }
 492
 493                this.update(cx, |this, cx| {
 494                    this.update_servers_task.take();
 495                    if this.needs_server_update {
 496                        this.available_context_servers_changed(cx);
 497                    }
 498                })?;
 499
 500                Ok(())
 501            }));
 502        }
 503    }
 504
 505    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 506        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 507            (
 508                this.context_server_settings.clone(),
 509                this.registry.clone(),
 510                this.worktree_store.clone(),
 511            )
 512        })?;
 513
 514        for (id, _) in
 515            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 516        {
 517            configured_servers
 518                .entry(id)
 519                .or_insert(ContextServerSettings::default_extension());
 520        }
 521
 522        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 523            configured_servers
 524                .into_iter()
 525                .partition(|(_, settings)| settings.enabled());
 526
 527        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 528            let id = ContextServerId(id);
 529            ContextServerConfiguration::from_settings(
 530                settings,
 531                id.clone(),
 532                registry.clone(),
 533                worktree_store.clone(),
 534                cx,
 535            )
 536            .map(|config| (id, config))
 537        }))
 538        .await
 539        .into_iter()
 540        .filter_map(|(id, config)| config.map(|config| (id, config)))
 541        .collect::<HashMap<_, _>>();
 542
 543        let mut servers_to_start = Vec::new();
 544        let mut servers_to_remove = HashSet::default();
 545        let mut servers_to_stop = HashSet::default();
 546
 547        this.update(cx, |this, _cx| {
 548            for server_id in this.servers.keys() {
 549                // All servers that are not in desired_servers should be removed from the store.
 550                // This can happen if the user removed a server from the context server settings.
 551                if !configured_servers.contains_key(&server_id) {
 552                    if disabled_servers.contains_key(&server_id.0) {
 553                        servers_to_stop.insert(server_id.clone());
 554                    } else {
 555                        servers_to_remove.insert(server_id.clone());
 556                    }
 557                }
 558            }
 559
 560            for (id, config) in configured_servers {
 561                let state = this.servers.get(&id);
 562                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 563                let existing_config = state.as_ref().map(|state| state.configuration());
 564                if existing_config.as_deref() != Some(&config) || is_stopped {
 565                    let config = Arc::new(config);
 566                    if let Some(server) = this
 567                        .create_context_server(id.clone(), config.clone())
 568                        .log_err()
 569                    {
 570                        servers_to_start.push((server, config));
 571                        if this.servers.contains_key(&id) {
 572                            servers_to_stop.insert(id);
 573                        }
 574                    }
 575                }
 576            }
 577        })?;
 578
 579        this.update(cx, |this, cx| {
 580            for id in servers_to_stop {
 581                this.stop_server(&id, cx)?;
 582            }
 583            for id in servers_to_remove {
 584                this.remove_server(&id, cx)?;
 585            }
 586            for (server, config) in servers_to_start {
 587                this.run_server(server, config, cx);
 588            }
 589            anyhow::Ok(())
 590        })?
 591    }
 592}
 593
 594#[cfg(test)]
 595mod tests {
 596    use super::*;
 597    use crate::{
 598        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 599        project_settings::ProjectSettings,
 600    };
 601    use context_server::test::create_fake_transport;
 602    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 603    use serde_json::json;
 604    use std::{cell::RefCell, rc::Rc};
 605    use util::path;
 606
 607    #[gpui::test]
 608    async fn test_context_server_status(cx: &mut TestAppContext) {
 609        const SERVER_1_ID: &'static str = "mcp-1";
 610        const SERVER_2_ID: &'static str = "mcp-2";
 611
 612        let (_fs, project) = setup_context_server_test(
 613            cx,
 614            json!({"code.rs": ""}),
 615            vec![
 616                (SERVER_1_ID.into(), dummy_server_settings()),
 617                (SERVER_2_ID.into(), dummy_server_settings()),
 618            ],
 619        )
 620        .await;
 621
 622        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 623        let store = cx.new(|cx| {
 624            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 625        });
 626
 627        let server_1_id = ContextServerId(SERVER_1_ID.into());
 628        let server_2_id = ContextServerId(SERVER_2_ID.into());
 629
 630        let server_1 = Arc::new(ContextServer::new(
 631            server_1_id.clone(),
 632            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 633        ));
 634        let server_2 = Arc::new(ContextServer::new(
 635            server_2_id.clone(),
 636            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 637        ));
 638
 639        store.update(cx, |store, cx| store.start_server(server_1, cx));
 640
 641        cx.run_until_parked();
 642
 643        cx.update(|cx| {
 644            assert_eq!(
 645                store.read(cx).status_for_server(&server_1_id),
 646                Some(ContextServerStatus::Running)
 647            );
 648            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 649        });
 650
 651        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 652
 653        cx.run_until_parked();
 654
 655        cx.update(|cx| {
 656            assert_eq!(
 657                store.read(cx).status_for_server(&server_1_id),
 658                Some(ContextServerStatus::Running)
 659            );
 660            assert_eq!(
 661                store.read(cx).status_for_server(&server_2_id),
 662                Some(ContextServerStatus::Running)
 663            );
 664        });
 665
 666        store
 667            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 668            .unwrap();
 669
 670        cx.update(|cx| {
 671            assert_eq!(
 672                store.read(cx).status_for_server(&server_1_id),
 673                Some(ContextServerStatus::Running)
 674            );
 675            assert_eq!(
 676                store.read(cx).status_for_server(&server_2_id),
 677                Some(ContextServerStatus::Stopped)
 678            );
 679        });
 680    }
 681
 682    #[gpui::test]
 683    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 684        const SERVER_1_ID: &'static str = "mcp-1";
 685        const SERVER_2_ID: &'static str = "mcp-2";
 686
 687        let (_fs, project) = setup_context_server_test(
 688            cx,
 689            json!({"code.rs": ""}),
 690            vec![
 691                (SERVER_1_ID.into(), dummy_server_settings()),
 692                (SERVER_2_ID.into(), dummy_server_settings()),
 693            ],
 694        )
 695        .await;
 696
 697        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 698        let store = cx.new(|cx| {
 699            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 700        });
 701
 702        let server_1_id = ContextServerId(SERVER_1_ID.into());
 703        let server_2_id = ContextServerId(SERVER_2_ID.into());
 704
 705        let server_1 = Arc::new(ContextServer::new(
 706            server_1_id.clone(),
 707            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 708        ));
 709        let server_2 = Arc::new(ContextServer::new(
 710            server_2_id.clone(),
 711            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 712        ));
 713
 714        let _server_events = assert_server_events(
 715            &store,
 716            vec![
 717                (server_1_id.clone(), ContextServerStatus::Starting),
 718                (server_1_id.clone(), ContextServerStatus::Running),
 719                (server_2_id.clone(), ContextServerStatus::Starting),
 720                (server_2_id.clone(), ContextServerStatus::Running),
 721                (server_2_id.clone(), ContextServerStatus::Stopped),
 722            ],
 723            cx,
 724        );
 725
 726        store.update(cx, |store, cx| store.start_server(server_1, cx));
 727
 728        cx.run_until_parked();
 729
 730        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 731
 732        cx.run_until_parked();
 733
 734        store
 735            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 736            .unwrap();
 737    }
 738
 739    #[gpui::test(iterations = 25)]
 740    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 741        const SERVER_1_ID: &'static str = "mcp-1";
 742
 743        let (_fs, project) = setup_context_server_test(
 744            cx,
 745            json!({"code.rs": ""}),
 746            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 747        )
 748        .await;
 749
 750        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 751        let store = cx.new(|cx| {
 752            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 753        });
 754
 755        let server_id = ContextServerId(SERVER_1_ID.into());
 756
 757        let server_with_same_id_1 = Arc::new(ContextServer::new(
 758            server_id.clone(),
 759            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 760        ));
 761        let server_with_same_id_2 = Arc::new(ContextServer::new(
 762            server_id.clone(),
 763            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 764        ));
 765
 766        // If we start another server with the same id, we should report that we stopped the previous one
 767        let _server_events = assert_server_events(
 768            &store,
 769            vec![
 770                (server_id.clone(), ContextServerStatus::Starting),
 771                (server_id.clone(), ContextServerStatus::Stopped),
 772                (server_id.clone(), ContextServerStatus::Starting),
 773                (server_id.clone(), ContextServerStatus::Running),
 774            ],
 775            cx,
 776        );
 777
 778        store.update(cx, |store, cx| {
 779            store.start_server(server_with_same_id_1.clone(), cx)
 780        });
 781        store.update(cx, |store, cx| {
 782            store.start_server(server_with_same_id_2.clone(), cx)
 783        });
 784
 785        cx.run_until_parked();
 786
 787        cx.update(|cx| {
 788            assert_eq!(
 789                store.read(cx).status_for_server(&server_id),
 790                Some(ContextServerStatus::Running)
 791            );
 792        });
 793    }
 794
 795    #[gpui::test]
 796    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 797        const SERVER_1_ID: &'static str = "mcp-1";
 798        const SERVER_2_ID: &'static str = "mcp-2";
 799
 800        let server_1_id = ContextServerId(SERVER_1_ID.into());
 801        let server_2_id = ContextServerId(SERVER_2_ID.into());
 802
 803        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 804
 805        let (_fs, project) = setup_context_server_test(
 806            cx,
 807            json!({"code.rs": ""}),
 808            vec![(
 809                SERVER_1_ID.into(),
 810                ContextServerSettings::Extension {
 811                    enabled: true,
 812                    settings: json!({
 813                        "somevalue": true
 814                    }),
 815                },
 816            )],
 817        )
 818        .await;
 819
 820        let executor = cx.executor();
 821        let registry = cx.new(|_| {
 822            let mut registry = ContextServerDescriptorRegistry::new();
 823            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
 824            registry
 825        });
 826        let store = cx.new(|cx| {
 827            ContextServerStore::test_maintain_server_loop(
 828                Box::new(move |id, _| {
 829                    Arc::new(ContextServer::new(
 830                        id.clone(),
 831                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 832                    ))
 833                }),
 834                registry.clone(),
 835                project.read(cx).worktree_store(),
 836                cx,
 837            )
 838        });
 839
 840        // Ensure that mcp-1 starts up
 841        {
 842            let _server_events = assert_server_events(
 843                &store,
 844                vec![
 845                    (server_1_id.clone(), ContextServerStatus::Starting),
 846                    (server_1_id.clone(), ContextServerStatus::Running),
 847                ],
 848                cx,
 849            );
 850            cx.run_until_parked();
 851        }
 852
 853        // Ensure that mcp-1 is restarted when the configuration was changed
 854        {
 855            let _server_events = assert_server_events(
 856                &store,
 857                vec![
 858                    (server_1_id.clone(), ContextServerStatus::Stopped),
 859                    (server_1_id.clone(), ContextServerStatus::Starting),
 860                    (server_1_id.clone(), ContextServerStatus::Running),
 861                ],
 862                cx,
 863            );
 864            set_context_server_configuration(
 865                vec![(
 866                    server_1_id.0.clone(),
 867                    ContextServerSettings::Extension {
 868                        enabled: true,
 869                        settings: json!({
 870                            "somevalue": false
 871                        }),
 872                    },
 873                )],
 874                cx,
 875            );
 876
 877            cx.run_until_parked();
 878        }
 879
 880        // Ensure that mcp-1 is not restarted when the configuration was not changed
 881        {
 882            let _server_events = assert_server_events(&store, vec![], cx);
 883            set_context_server_configuration(
 884                vec![(
 885                    server_1_id.0.clone(),
 886                    ContextServerSettings::Extension {
 887                        enabled: true,
 888                        settings: json!({
 889                            "somevalue": false
 890                        }),
 891                    },
 892                )],
 893                cx,
 894            );
 895
 896            cx.run_until_parked();
 897        }
 898
 899        // Ensure that mcp-2 is started once it is added to the settings
 900        {
 901            let _server_events = assert_server_events(
 902                &store,
 903                vec![
 904                    (server_2_id.clone(), ContextServerStatus::Starting),
 905                    (server_2_id.clone(), ContextServerStatus::Running),
 906                ],
 907                cx,
 908            );
 909            set_context_server_configuration(
 910                vec![
 911                    (
 912                        server_1_id.0.clone(),
 913                        ContextServerSettings::Extension {
 914                            enabled: true,
 915                            settings: json!({
 916                                "somevalue": false
 917                            }),
 918                        },
 919                    ),
 920                    (
 921                        server_2_id.0.clone(),
 922                        ContextServerSettings::Custom {
 923                            enabled: true,
 924                            command: ContextServerCommand {
 925                                path: "somebinary".to_string(),
 926                                args: vec!["arg".to_string()],
 927                                env: None,
 928                            },
 929                        },
 930                    ),
 931                ],
 932                cx,
 933            );
 934
 935            cx.run_until_parked();
 936        }
 937
 938        // Ensure that mcp-2 is restarted once the args have changed
 939        {
 940            let _server_events = assert_server_events(
 941                &store,
 942                vec![
 943                    (server_2_id.clone(), ContextServerStatus::Stopped),
 944                    (server_2_id.clone(), ContextServerStatus::Starting),
 945                    (server_2_id.clone(), ContextServerStatus::Running),
 946                ],
 947                cx,
 948            );
 949            set_context_server_configuration(
 950                vec![
 951                    (
 952                        server_1_id.0.clone(),
 953                        ContextServerSettings::Extension {
 954                            enabled: true,
 955                            settings: json!({
 956                                "somevalue": false
 957                            }),
 958                        },
 959                    ),
 960                    (
 961                        server_2_id.0.clone(),
 962                        ContextServerSettings::Custom {
 963                            enabled: true,
 964                            command: ContextServerCommand {
 965                                path: "somebinary".to_string(),
 966                                args: vec!["anotherArg".to_string()],
 967                                env: None,
 968                            },
 969                        },
 970                    ),
 971                ],
 972                cx,
 973            );
 974
 975            cx.run_until_parked();
 976        }
 977
 978        // Ensure that mcp-2 is removed once it is removed from the settings
 979        {
 980            let _server_events = assert_server_events(
 981                &store,
 982                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 983                cx,
 984            );
 985            set_context_server_configuration(
 986                vec![(
 987                    server_1_id.0.clone(),
 988                    ContextServerSettings::Extension {
 989                        enabled: true,
 990                        settings: json!({
 991                            "somevalue": false
 992                        }),
 993                    },
 994                )],
 995                cx,
 996            );
 997
 998            cx.run_until_parked();
 999
1000            cx.update(|cx| {
1001                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1002            });
1003        }
1004
1005        // Ensure that nothing happens if the settings do not change
1006        {
1007            let _server_events = assert_server_events(&store, vec![], cx);
1008            set_context_server_configuration(
1009                vec![(
1010                    server_1_id.0.clone(),
1011                    ContextServerSettings::Extension {
1012                        enabled: true,
1013                        settings: json!({
1014                            "somevalue": false
1015                        }),
1016                    },
1017                )],
1018                cx,
1019            );
1020
1021            cx.run_until_parked();
1022
1023            cx.update(|cx| {
1024                assert_eq!(
1025                    store.read(cx).status_for_server(&server_1_id),
1026                    Some(ContextServerStatus::Running)
1027                );
1028                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1029            });
1030        }
1031    }
1032
1033    #[gpui::test]
1034    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1035        const SERVER_1_ID: &'static str = "mcp-1";
1036
1037        let server_1_id = ContextServerId(SERVER_1_ID.into());
1038
1039        let (_fs, project) = setup_context_server_test(
1040            cx,
1041            json!({"code.rs": ""}),
1042            vec![(
1043                SERVER_1_ID.into(),
1044                ContextServerSettings::Custom {
1045                    enabled: true,
1046                    command: ContextServerCommand {
1047                        path: "somebinary".to_string(),
1048                        args: vec!["arg".to_string()],
1049                        env: None,
1050                    },
1051                },
1052            )],
1053        )
1054        .await;
1055
1056        let executor = cx.executor();
1057        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1058        let store = cx.new(|cx| {
1059            ContextServerStore::test_maintain_server_loop(
1060                Box::new(move |id, _| {
1061                    Arc::new(ContextServer::new(
1062                        id.clone(),
1063                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1064                    ))
1065                }),
1066                registry.clone(),
1067                project.read(cx).worktree_store(),
1068                cx,
1069            )
1070        });
1071
1072        // Ensure that mcp-1 starts up
1073        {
1074            let _server_events = assert_server_events(
1075                &store,
1076                vec![
1077                    (server_1_id.clone(), ContextServerStatus::Starting),
1078                    (server_1_id.clone(), ContextServerStatus::Running),
1079                ],
1080                cx,
1081            );
1082            cx.run_until_parked();
1083        }
1084
1085        // Ensure that mcp-1 is stopped once it is disabled.
1086        {
1087            let _server_events = assert_server_events(
1088                &store,
1089                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1090                cx,
1091            );
1092            set_context_server_configuration(
1093                vec![(
1094                    server_1_id.0.clone(),
1095                    ContextServerSettings::Custom {
1096                        enabled: false,
1097                        command: ContextServerCommand {
1098                            path: "somebinary".to_string(),
1099                            args: vec!["arg".to_string()],
1100                            env: None,
1101                        },
1102                    },
1103                )],
1104                cx,
1105            );
1106
1107            cx.run_until_parked();
1108        }
1109
1110        // Ensure that mcp-1 is started once it is enabled again.
1111        {
1112            let _server_events = assert_server_events(
1113                &store,
1114                vec![
1115                    (server_1_id.clone(), ContextServerStatus::Starting),
1116                    (server_1_id.clone(), ContextServerStatus::Running),
1117                ],
1118                cx,
1119            );
1120            set_context_server_configuration(
1121                vec![(
1122                    server_1_id.0.clone(),
1123                    ContextServerSettings::Custom {
1124                        enabled: true,
1125                        command: ContextServerCommand {
1126                            path: "somebinary".to_string(),
1127                            args: vec!["arg".to_string()],
1128                            env: None,
1129                        },
1130                    },
1131                )],
1132                cx,
1133            );
1134
1135            cx.run_until_parked();
1136        }
1137    }
1138
1139    fn set_context_server_configuration(
1140        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1141        cx: &mut TestAppContext,
1142    ) {
1143        cx.update(|cx| {
1144            SettingsStore::update_global(cx, |store, cx| {
1145                let mut settings = ProjectSettings::default();
1146                for (id, config) in context_servers {
1147                    settings.context_servers.insert(id, config);
1148                }
1149                store
1150                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1151                    .unwrap();
1152            })
1153        });
1154    }
1155
1156    struct ServerEvents {
1157        received_event_count: Rc<RefCell<usize>>,
1158        expected_event_count: usize,
1159        _subscription: Subscription,
1160    }
1161
1162    impl Drop for ServerEvents {
1163        fn drop(&mut self) {
1164            let actual_event_count = *self.received_event_count.borrow();
1165            assert_eq!(
1166                actual_event_count, self.expected_event_count,
1167                "
1168                Expected to receive {} context server store events, but received {} events",
1169                self.expected_event_count, actual_event_count
1170            );
1171        }
1172    }
1173
1174    fn dummy_server_settings() -> ContextServerSettings {
1175        ContextServerSettings::Custom {
1176            enabled: true,
1177            command: ContextServerCommand {
1178                path: "somebinary".to_string(),
1179                args: vec!["arg".to_string()],
1180                env: None,
1181            },
1182        }
1183    }
1184
1185    fn assert_server_events(
1186        store: &Entity<ContextServerStore>,
1187        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1188        cx: &mut TestAppContext,
1189    ) -> ServerEvents {
1190        cx.update(|cx| {
1191            let mut ix = 0;
1192            let received_event_count = Rc::new(RefCell::new(0));
1193            let expected_event_count = expected_events.len();
1194            let subscription = cx.subscribe(store, {
1195                let received_event_count = received_event_count.clone();
1196                move |_, event, _| match event {
1197                    Event::ServerStatusChanged {
1198                        server_id: actual_server_id,
1199                        status: actual_status,
1200                    } => {
1201                        let (expected_server_id, expected_status) = &expected_events[ix];
1202
1203                        assert_eq!(
1204                            actual_server_id, expected_server_id,
1205                            "Expected different server id at index {}",
1206                            ix
1207                        );
1208                        assert_eq!(
1209                            actual_status, expected_status,
1210                            "Expected different status at index {}",
1211                            ix
1212                        );
1213                        ix += 1;
1214                        *received_event_count.borrow_mut() += 1;
1215                    }
1216                }
1217            });
1218            ServerEvents {
1219                expected_event_count,
1220                received_event_count,
1221                _subscription: subscription,
1222            }
1223        })
1224    }
1225
1226    async fn setup_context_server_test(
1227        cx: &mut TestAppContext,
1228        files: serde_json::Value,
1229        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1230    ) -> (Arc<FakeFs>, Entity<Project>) {
1231        cx.update(|cx| {
1232            let settings_store = SettingsStore::test(cx);
1233            cx.set_global(settings_store);
1234            Project::init_settings(cx);
1235            let mut settings = ProjectSettings::get_global(cx).clone();
1236            for (id, config) in context_server_configurations {
1237                settings.context_servers.insert(id, config);
1238            }
1239            ProjectSettings::override_global(settings, cx);
1240        });
1241
1242        let fs = FakeFs::new(cx.executor());
1243        fs.insert_tree(path!("/test"), files).await;
1244        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1245
1246        (fs, project)
1247    }
1248
1249    struct FakeContextServerDescriptor {
1250        path: String,
1251    }
1252
1253    impl FakeContextServerDescriptor {
1254        fn new(path: impl Into<String>) -> Self {
1255            Self { path: path.into() }
1256        }
1257    }
1258
1259    impl ContextServerDescriptor for FakeContextServerDescriptor {
1260        fn command(
1261            &self,
1262            _worktree_store: Entity<WorktreeStore>,
1263            _cx: &AsyncApp,
1264        ) -> Task<Result<ContextServerCommand>> {
1265            Task::ready(Ok(ContextServerCommand {
1266                path: self.path.clone(),
1267                args: vec!["arg1".to_string(), "arg2".to_string()],
1268                env: None,
1269            }))
1270        }
1271
1272        fn configuration(
1273            &self,
1274            _worktree_store: Entity<WorktreeStore>,
1275            _cx: &AsyncApp,
1276        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1277            Task::ready(Ok(None))
1278        }
1279    }
1280}