context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    project_settings::{ContextServerSettings, ProjectSettings},
  17    worktree_store::WorktreeStore,
  18};
  19
  20pub fn init(cx: &mut App) {
  21    extension::init(cx);
  22}
  23
  24actions!(context_server, [Restart]);
  25
  26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  27pub enum ContextServerStatus {
  28    Starting,
  29    Running,
  30    Stopped,
  31    Error(Arc<str>),
  32}
  33
  34impl ContextServerStatus {
  35    fn from_state(state: &ContextServerState) -> Self {
  36        match state {
  37            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  38            ContextServerState::Running { .. } => ContextServerStatus::Running,
  39            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  40            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  41        }
  42    }
  43}
  44
  45enum ContextServerState {
  46    Starting {
  47        server: Arc<ContextServer>,
  48        configuration: Arc<ContextServerConfiguration>,
  49        _task: Task<()>,
  50    },
  51    Running {
  52        server: Arc<ContextServer>,
  53        configuration: Arc<ContextServerConfiguration>,
  54    },
  55    Stopped {
  56        server: Arc<ContextServer>,
  57        configuration: Arc<ContextServerConfiguration>,
  58    },
  59    Error {
  60        server: Arc<ContextServer>,
  61        configuration: Arc<ContextServerConfiguration>,
  62        error: Arc<str>,
  63    },
  64}
  65
  66impl ContextServerState {
  67    pub fn server(&self) -> Arc<ContextServer> {
  68        match self {
  69            ContextServerState::Starting { server, .. } => server.clone(),
  70            ContextServerState::Running { server, .. } => server.clone(),
  71            ContextServerState::Stopped { server, .. } => server.clone(),
  72            ContextServerState::Error { server, .. } => server.clone(),
  73        }
  74    }
  75
  76    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  77        match self {
  78            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  79            ContextServerState::Running { configuration, .. } => configuration.clone(),
  80            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  81            ContextServerState::Error { configuration, .. } => configuration.clone(),
  82        }
  83    }
  84}
  85
  86#[derive(Debug, PartialEq, Eq)]
  87pub enum ContextServerConfiguration {
  88    Custom {
  89        command: ContextServerCommand,
  90    },
  91    Extension {
  92        command: ContextServerCommand,
  93        settings: serde_json::Value,
  94    },
  95}
  96
  97impl ContextServerConfiguration {
  98    pub fn command(&self) -> &ContextServerCommand {
  99        match self {
 100            ContextServerConfiguration::Custom { command } => command,
 101            ContextServerConfiguration::Extension { command, .. } => command,
 102        }
 103    }
 104
 105    pub async fn from_settings(
 106        settings: ContextServerSettings,
 107        id: ContextServerId,
 108        registry: Entity<ContextServerDescriptorRegistry>,
 109        worktree_store: Entity<WorktreeStore>,
 110        cx: &AsyncApp,
 111    ) -> Option<Self> {
 112        match settings {
 113            ContextServerSettings::Custom {
 114                enabled: _,
 115                command,
 116            } => Some(ContextServerConfiguration::Custom { command }),
 117            ContextServerSettings::Extension {
 118                enabled: _,
 119                settings,
 120            } => {
 121                let descriptor = cx
 122                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 123                    .ok()
 124                    .flatten()?;
 125
 126                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 127
 128                Some(ContextServerConfiguration::Extension { command, settings })
 129            }
 130        }
 131    }
 132}
 133
 134pub type ContextServerFactory =
 135    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 136
 137pub struct ContextServerStore {
 138    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
 139    servers: HashMap<ContextServerId, ContextServerState>,
 140    worktree_store: Entity<WorktreeStore>,
 141    registry: Entity<ContextServerDescriptorRegistry>,
 142    update_servers_task: Option<Task<Result<()>>>,
 143    context_server_factory: Option<ContextServerFactory>,
 144    needs_server_update: bool,
 145    _subscriptions: Vec<Subscription>,
 146}
 147
 148pub enum Event {
 149    ServerStatusChanged {
 150        server_id: ContextServerId,
 151        status: ContextServerStatus,
 152    },
 153}
 154
 155impl EventEmitter<Event> for ContextServerStore {}
 156
 157impl ContextServerStore {
 158    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 159        Self::new_internal(
 160            true,
 161            None,
 162            ContextServerDescriptorRegistry::default_global(cx),
 163            worktree_store,
 164            cx,
 165        )
 166    }
 167
 168    #[cfg(any(test, feature = "test-support"))]
 169    pub fn test(
 170        registry: Entity<ContextServerDescriptorRegistry>,
 171        worktree_store: Entity<WorktreeStore>,
 172        cx: &mut Context<Self>,
 173    ) -> Self {
 174        Self::new_internal(false, None, registry, worktree_store, cx)
 175    }
 176
 177    #[cfg(any(test, feature = "test-support"))]
 178    pub fn test_maintain_server_loop(
 179        context_server_factory: ContextServerFactory,
 180        registry: Entity<ContextServerDescriptorRegistry>,
 181        worktree_store: Entity<WorktreeStore>,
 182        cx: &mut Context<Self>,
 183    ) -> Self {
 184        Self::new_internal(
 185            true,
 186            Some(context_server_factory),
 187            registry,
 188            worktree_store,
 189            cx,
 190        )
 191    }
 192
 193    fn new_internal(
 194        maintain_server_loop: bool,
 195        context_server_factory: Option<ContextServerFactory>,
 196        registry: Entity<ContextServerDescriptorRegistry>,
 197        worktree_store: Entity<WorktreeStore>,
 198        cx: &mut Context<Self>,
 199    ) -> Self {
 200        let subscriptions = if maintain_server_loop {
 201            vec![
 202                cx.observe(&registry, |this, _registry, cx| {
 203                    this.available_context_servers_changed(cx);
 204                }),
 205                cx.observe_global::<SettingsStore>(|this, cx| {
 206                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 207                    if &this.context_server_settings == settings {
 208                        return;
 209                    }
 210                    this.context_server_settings = settings.clone();
 211                    this.available_context_servers_changed(cx);
 212                }),
 213            ]
 214        } else {
 215            Vec::new()
 216        };
 217
 218        let mut this = Self {
 219            _subscriptions: subscriptions,
 220            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 221                .clone(),
 222            worktree_store,
 223            registry,
 224            needs_server_update: false,
 225            servers: HashMap::default(),
 226            update_servers_task: None,
 227            context_server_factory,
 228        };
 229        if maintain_server_loop {
 230            this.available_context_servers_changed(cx);
 231        }
 232        this
 233    }
 234
 235    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 236        self.servers.get(id).map(|state| state.server())
 237    }
 238
 239    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 240        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 241            Some(server.clone())
 242        } else {
 243            None
 244        }
 245    }
 246
 247    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 248        self.servers.get(id).map(ContextServerStatus::from_state)
 249    }
 250
 251    pub fn configuration_for_server(
 252        &self,
 253        id: &ContextServerId,
 254    ) -> Option<Arc<ContextServerConfiguration>> {
 255        self.servers.get(id).map(|state| state.configuration())
 256    }
 257
 258    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 259        self.servers.keys().cloned().collect()
 260    }
 261
 262    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 263        self.servers
 264            .values()
 265            .filter_map(|state| {
 266                if let ContextServerState::Running { server, .. } = state {
 267                    Some(server.clone())
 268                } else {
 269                    None
 270                }
 271            })
 272            .collect()
 273    }
 274
 275    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 276        cx.spawn(async move |this, cx| {
 277            let this = this.upgrade().context("Context server store dropped")?;
 278            let settings = this
 279                .update(cx, |this, _| {
 280                    this.context_server_settings.get(&server.id().0).cloned()
 281                })
 282                .ok()
 283                .flatten()
 284                .context("Failed to get context server settings")?;
 285
 286            if !settings.enabled() {
 287                return Ok(());
 288            }
 289
 290            let (registry, worktree_store) = this.update(cx, |this, _| {
 291                (this.registry.clone(), this.worktree_store.clone())
 292            })?;
 293            let configuration = ContextServerConfiguration::from_settings(
 294                settings,
 295                server.id(),
 296                registry,
 297                worktree_store,
 298                cx,
 299            )
 300            .await
 301            .context("Failed to create context server configuration")?;
 302
 303            this.update(cx, |this, cx| {
 304                this.run_server(server, Arc::new(configuration), cx)
 305            })
 306        })
 307        .detach_and_log_err(cx);
 308    }
 309
 310    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 311        if matches!(
 312            self.servers.get(id),
 313            Some(ContextServerState::Stopped { .. })
 314        ) {
 315            return Ok(());
 316        }
 317
 318        let state = self
 319            .servers
 320            .remove(id)
 321            .context("Context server not found")?;
 322
 323        let server = state.server();
 324        let configuration = state.configuration();
 325        let mut result = Ok(());
 326        if let ContextServerState::Running { server, .. } = &state {
 327            result = server.stop();
 328        }
 329        drop(state);
 330
 331        self.update_server_state(
 332            id.clone(),
 333            ContextServerState::Stopped {
 334                configuration,
 335                server,
 336            },
 337            cx,
 338        );
 339
 340        result
 341    }
 342
 343    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 344        if let Some(state) = self.servers.get(&id) {
 345            let configuration = state.configuration();
 346
 347            self.stop_server(&state.server().id(), cx)?;
 348            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 349            self.run_server(new_server, configuration, cx);
 350        }
 351        Ok(())
 352    }
 353
 354    fn run_server(
 355        &mut self,
 356        server: Arc<ContextServer>,
 357        configuration: Arc<ContextServerConfiguration>,
 358        cx: &mut Context<Self>,
 359    ) {
 360        let id = server.id();
 361        if matches!(
 362            self.servers.get(&id),
 363            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 364        ) {
 365            self.stop_server(&id, cx).log_err();
 366        }
 367
 368        let task = cx.spawn({
 369            let id = server.id();
 370            let server = server.clone();
 371            let configuration = configuration.clone();
 372            async move |this, cx| {
 373                match server.clone().start(&cx).await {
 374                    Ok(_) => {
 375                        log::info!("Started {} context server", id);
 376                        debug_assert!(server.client().is_some());
 377
 378                        this.update(cx, |this, cx| {
 379                            this.update_server_state(
 380                                id.clone(),
 381                                ContextServerState::Running {
 382                                    server,
 383                                    configuration,
 384                                },
 385                                cx,
 386                            )
 387                        })
 388                        .log_err()
 389                    }
 390                    Err(err) => {
 391                        log::error!("{} context server failed to start: {}", id, err);
 392                        this.update(cx, |this, cx| {
 393                            this.update_server_state(
 394                                id.clone(),
 395                                ContextServerState::Error {
 396                                    configuration,
 397                                    server,
 398                                    error: err.to_string().into(),
 399                                },
 400                                cx,
 401                            )
 402                        })
 403                        .log_err()
 404                    }
 405                };
 406            }
 407        });
 408
 409        self.update_server_state(
 410            id.clone(),
 411            ContextServerState::Starting {
 412                configuration,
 413                _task: task,
 414                server,
 415            },
 416            cx,
 417        );
 418    }
 419
 420    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 421        let state = self
 422            .servers
 423            .remove(id)
 424            .context("Context server not found")?;
 425        drop(state);
 426        cx.emit(Event::ServerStatusChanged {
 427            server_id: id.clone(),
 428            status: ContextServerStatus::Stopped,
 429        });
 430        Ok(())
 431    }
 432
 433    fn create_context_server(
 434        &self,
 435        id: ContextServerId,
 436        configuration: Arc<ContextServerConfiguration>,
 437    ) -> Result<Arc<ContextServer>> {
 438        if let Some(factory) = self.context_server_factory.as_ref() {
 439            Ok(factory(id, configuration))
 440        } else {
 441            Ok(Arc::new(ContextServer::stdio(
 442                id,
 443                configuration.command().clone(),
 444            )))
 445        }
 446    }
 447
 448    fn resolve_context_server_settings<'a>(
 449        worktree_store: &'a Entity<WorktreeStore>,
 450        cx: &'a App,
 451    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 452        let location = worktree_store
 453            .read(cx)
 454            .visible_worktrees(cx)
 455            .next()
 456            .map(|worktree| settings::SettingsLocation {
 457                worktree_id: worktree.read(cx).id(),
 458                path: Path::new(""),
 459            });
 460        &ProjectSettings::get(location, cx).context_servers
 461    }
 462
 463    fn update_server_state(
 464        &mut self,
 465        id: ContextServerId,
 466        state: ContextServerState,
 467        cx: &mut Context<Self>,
 468    ) {
 469        let status = ContextServerStatus::from_state(&state);
 470        self.servers.insert(id.clone(), state);
 471        cx.emit(Event::ServerStatusChanged {
 472            server_id: id,
 473            status,
 474        });
 475    }
 476
 477    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 478        if self.update_servers_task.is_some() {
 479            self.needs_server_update = true;
 480        } else {
 481            self.needs_server_update = false;
 482            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 483                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 484                    log::error!("Error maintaining context servers: {}", err);
 485                }
 486
 487                this.update(cx, |this, cx| {
 488                    this.update_servers_task.take();
 489                    if this.needs_server_update {
 490                        this.available_context_servers_changed(cx);
 491                    }
 492                })?;
 493
 494                Ok(())
 495            }));
 496        }
 497    }
 498
 499    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 500        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 501            (
 502                this.context_server_settings.clone(),
 503                this.registry.clone(),
 504                this.worktree_store.clone(),
 505            )
 506        })?;
 507
 508        for (id, _) in
 509            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 510        {
 511            configured_servers
 512                .entry(id)
 513                .or_insert(ContextServerSettings::default_extension());
 514        }
 515
 516        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 517            configured_servers
 518                .into_iter()
 519                .partition(|(_, settings)| settings.enabled());
 520
 521        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 522            let id = ContextServerId(id);
 523            ContextServerConfiguration::from_settings(
 524                settings,
 525                id.clone(),
 526                registry.clone(),
 527                worktree_store.clone(),
 528                cx,
 529            )
 530            .map(|config| (id, config))
 531        }))
 532        .await
 533        .into_iter()
 534        .filter_map(|(id, config)| config.map(|config| (id, config)))
 535        .collect::<HashMap<_, _>>();
 536
 537        let mut servers_to_start = Vec::new();
 538        let mut servers_to_remove = HashSet::default();
 539        let mut servers_to_stop = HashSet::default();
 540
 541        this.update(cx, |this, _cx| {
 542            for server_id in this.servers.keys() {
 543                // All servers that are not in desired_servers should be removed from the store.
 544                // This can happen if the user removed a server from the context server settings.
 545                if !configured_servers.contains_key(&server_id) {
 546                    if disabled_servers.contains_key(&server_id.0) {
 547                        servers_to_stop.insert(server_id.clone());
 548                    } else {
 549                        servers_to_remove.insert(server_id.clone());
 550                    }
 551                }
 552            }
 553
 554            for (id, config) in configured_servers {
 555                let state = this.servers.get(&id);
 556                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 557                let existing_config = state.as_ref().map(|state| state.configuration());
 558                if existing_config.as_deref() != Some(&config) || is_stopped {
 559                    let config = Arc::new(config);
 560                    if let Some(server) = this
 561                        .create_context_server(id.clone(), config.clone())
 562                        .log_err()
 563                    {
 564                        servers_to_start.push((server, config));
 565                        if this.servers.contains_key(&id) {
 566                            servers_to_stop.insert(id);
 567                        }
 568                    }
 569                }
 570            }
 571        })?;
 572
 573        this.update(cx, |this, cx| {
 574            for id in servers_to_stop {
 575                this.stop_server(&id, cx)?;
 576            }
 577            for id in servers_to_remove {
 578                this.remove_server(&id, cx)?;
 579            }
 580            for (server, config) in servers_to_start {
 581                this.run_server(server, config, cx);
 582            }
 583            anyhow::Ok(())
 584        })?
 585    }
 586}
 587
 588#[cfg(test)]
 589mod tests {
 590    use super::*;
 591    use crate::{
 592        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 593        project_settings::ProjectSettings,
 594    };
 595    use context_server::test::create_fake_transport;
 596    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 597    use serde_json::json;
 598    use std::{cell::RefCell, rc::Rc};
 599    use util::path;
 600
 601    #[gpui::test]
 602    async fn test_context_server_status(cx: &mut TestAppContext) {
 603        const SERVER_1_ID: &'static str = "mcp-1";
 604        const SERVER_2_ID: &'static str = "mcp-2";
 605
 606        let (_fs, project) = setup_context_server_test(
 607            cx,
 608            json!({"code.rs": ""}),
 609            vec![
 610                (SERVER_1_ID.into(), dummy_server_settings()),
 611                (SERVER_2_ID.into(), dummy_server_settings()),
 612            ],
 613        )
 614        .await;
 615
 616        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 617        let store = cx.new(|cx| {
 618            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 619        });
 620
 621        let server_1_id = ContextServerId(SERVER_1_ID.into());
 622        let server_2_id = ContextServerId(SERVER_2_ID.into());
 623
 624        let server_1 = Arc::new(ContextServer::new(
 625            server_1_id.clone(),
 626            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 627        ));
 628        let server_2 = Arc::new(ContextServer::new(
 629            server_2_id.clone(),
 630            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 631        ));
 632
 633        store.update(cx, |store, cx| store.start_server(server_1, cx));
 634
 635        cx.run_until_parked();
 636
 637        cx.update(|cx| {
 638            assert_eq!(
 639                store.read(cx).status_for_server(&server_1_id),
 640                Some(ContextServerStatus::Running)
 641            );
 642            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 643        });
 644
 645        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 646
 647        cx.run_until_parked();
 648
 649        cx.update(|cx| {
 650            assert_eq!(
 651                store.read(cx).status_for_server(&server_1_id),
 652                Some(ContextServerStatus::Running)
 653            );
 654            assert_eq!(
 655                store.read(cx).status_for_server(&server_2_id),
 656                Some(ContextServerStatus::Running)
 657            );
 658        });
 659
 660        store
 661            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 662            .unwrap();
 663
 664        cx.update(|cx| {
 665            assert_eq!(
 666                store.read(cx).status_for_server(&server_1_id),
 667                Some(ContextServerStatus::Running)
 668            );
 669            assert_eq!(
 670                store.read(cx).status_for_server(&server_2_id),
 671                Some(ContextServerStatus::Stopped)
 672            );
 673        });
 674    }
 675
 676    #[gpui::test]
 677    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 678        const SERVER_1_ID: &'static str = "mcp-1";
 679        const SERVER_2_ID: &'static str = "mcp-2";
 680
 681        let (_fs, project) = setup_context_server_test(
 682            cx,
 683            json!({"code.rs": ""}),
 684            vec![
 685                (SERVER_1_ID.into(), dummy_server_settings()),
 686                (SERVER_2_ID.into(), dummy_server_settings()),
 687            ],
 688        )
 689        .await;
 690
 691        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 692        let store = cx.new(|cx| {
 693            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 694        });
 695
 696        let server_1_id = ContextServerId(SERVER_1_ID.into());
 697        let server_2_id = ContextServerId(SERVER_2_ID.into());
 698
 699        let server_1 = Arc::new(ContextServer::new(
 700            server_1_id.clone(),
 701            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 702        ));
 703        let server_2 = Arc::new(ContextServer::new(
 704            server_2_id.clone(),
 705            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 706        ));
 707
 708        let _server_events = assert_server_events(
 709            &store,
 710            vec![
 711                (server_1_id.clone(), ContextServerStatus::Starting),
 712                (server_1_id.clone(), ContextServerStatus::Running),
 713                (server_2_id.clone(), ContextServerStatus::Starting),
 714                (server_2_id.clone(), ContextServerStatus::Running),
 715                (server_2_id.clone(), ContextServerStatus::Stopped),
 716            ],
 717            cx,
 718        );
 719
 720        store.update(cx, |store, cx| store.start_server(server_1, cx));
 721
 722        cx.run_until_parked();
 723
 724        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 725
 726        cx.run_until_parked();
 727
 728        store
 729            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 730            .unwrap();
 731    }
 732
 733    #[gpui::test(iterations = 25)]
 734    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 735        const SERVER_1_ID: &'static str = "mcp-1";
 736
 737        let (_fs, project) = setup_context_server_test(
 738            cx,
 739            json!({"code.rs": ""}),
 740            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 741        )
 742        .await;
 743
 744        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 745        let store = cx.new(|cx| {
 746            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 747        });
 748
 749        let server_id = ContextServerId(SERVER_1_ID.into());
 750
 751        let server_with_same_id_1 = Arc::new(ContextServer::new(
 752            server_id.clone(),
 753            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 754        ));
 755        let server_with_same_id_2 = Arc::new(ContextServer::new(
 756            server_id.clone(),
 757            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 758        ));
 759
 760        // If we start another server with the same id, we should report that we stopped the previous one
 761        let _server_events = assert_server_events(
 762            &store,
 763            vec![
 764                (server_id.clone(), ContextServerStatus::Starting),
 765                (server_id.clone(), ContextServerStatus::Stopped),
 766                (server_id.clone(), ContextServerStatus::Starting),
 767                (server_id.clone(), ContextServerStatus::Running),
 768            ],
 769            cx,
 770        );
 771
 772        store.update(cx, |store, cx| {
 773            store.start_server(server_with_same_id_1.clone(), cx)
 774        });
 775        store.update(cx, |store, cx| {
 776            store.start_server(server_with_same_id_2.clone(), cx)
 777        });
 778
 779        cx.run_until_parked();
 780
 781        cx.update(|cx| {
 782            assert_eq!(
 783                store.read(cx).status_for_server(&server_id),
 784                Some(ContextServerStatus::Running)
 785            );
 786        });
 787    }
 788
 789    #[gpui::test]
 790    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 791        const SERVER_1_ID: &'static str = "mcp-1";
 792        const SERVER_2_ID: &'static str = "mcp-2";
 793
 794        let server_1_id = ContextServerId(SERVER_1_ID.into());
 795        let server_2_id = ContextServerId(SERVER_2_ID.into());
 796
 797        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 798
 799        let (_fs, project) = setup_context_server_test(
 800            cx,
 801            json!({"code.rs": ""}),
 802            vec![(
 803                SERVER_1_ID.into(),
 804                ContextServerSettings::Extension {
 805                    enabled: true,
 806                    settings: json!({
 807                        "somevalue": true
 808                    }),
 809                },
 810            )],
 811        )
 812        .await;
 813
 814        let executor = cx.executor();
 815        let registry = cx.new(|_| {
 816            let mut registry = ContextServerDescriptorRegistry::new();
 817            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
 818            registry
 819        });
 820        let store = cx.new(|cx| {
 821            ContextServerStore::test_maintain_server_loop(
 822                Box::new(move |id, _| {
 823                    Arc::new(ContextServer::new(
 824                        id.clone(),
 825                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 826                    ))
 827                }),
 828                registry.clone(),
 829                project.read(cx).worktree_store(),
 830                cx,
 831            )
 832        });
 833
 834        // Ensure that mcp-1 starts up
 835        {
 836            let _server_events = assert_server_events(
 837                &store,
 838                vec![
 839                    (server_1_id.clone(), ContextServerStatus::Starting),
 840                    (server_1_id.clone(), ContextServerStatus::Running),
 841                ],
 842                cx,
 843            );
 844            cx.run_until_parked();
 845        }
 846
 847        // Ensure that mcp-1 is restarted when the configuration was changed
 848        {
 849            let _server_events = assert_server_events(
 850                &store,
 851                vec![
 852                    (server_1_id.clone(), ContextServerStatus::Stopped),
 853                    (server_1_id.clone(), ContextServerStatus::Starting),
 854                    (server_1_id.clone(), ContextServerStatus::Running),
 855                ],
 856                cx,
 857            );
 858            set_context_server_configuration(
 859                vec![(
 860                    server_1_id.0.clone(),
 861                    ContextServerSettings::Extension {
 862                        enabled: true,
 863                        settings: json!({
 864                            "somevalue": false
 865                        }),
 866                    },
 867                )],
 868                cx,
 869            );
 870
 871            cx.run_until_parked();
 872        }
 873
 874        // Ensure that mcp-1 is not restarted when the configuration was not changed
 875        {
 876            let _server_events = assert_server_events(&store, vec![], cx);
 877            set_context_server_configuration(
 878                vec![(
 879                    server_1_id.0.clone(),
 880                    ContextServerSettings::Extension {
 881                        enabled: true,
 882                        settings: json!({
 883                            "somevalue": false
 884                        }),
 885                    },
 886                )],
 887                cx,
 888            );
 889
 890            cx.run_until_parked();
 891        }
 892
 893        // Ensure that mcp-2 is started once it is added to the settings
 894        {
 895            let _server_events = assert_server_events(
 896                &store,
 897                vec![
 898                    (server_2_id.clone(), ContextServerStatus::Starting),
 899                    (server_2_id.clone(), ContextServerStatus::Running),
 900                ],
 901                cx,
 902            );
 903            set_context_server_configuration(
 904                vec![
 905                    (
 906                        server_1_id.0.clone(),
 907                        ContextServerSettings::Extension {
 908                            enabled: true,
 909                            settings: json!({
 910                                "somevalue": false
 911                            }),
 912                        },
 913                    ),
 914                    (
 915                        server_2_id.0.clone(),
 916                        ContextServerSettings::Custom {
 917                            enabled: true,
 918                            command: ContextServerCommand {
 919                                path: "somebinary".to_string(),
 920                                args: vec!["arg".to_string()],
 921                                env: None,
 922                            },
 923                        },
 924                    ),
 925                ],
 926                cx,
 927            );
 928
 929            cx.run_until_parked();
 930        }
 931
 932        // Ensure that mcp-2 is restarted once the args have changed
 933        {
 934            let _server_events = assert_server_events(
 935                &store,
 936                vec![
 937                    (server_2_id.clone(), ContextServerStatus::Stopped),
 938                    (server_2_id.clone(), ContextServerStatus::Starting),
 939                    (server_2_id.clone(), ContextServerStatus::Running),
 940                ],
 941                cx,
 942            );
 943            set_context_server_configuration(
 944                vec![
 945                    (
 946                        server_1_id.0.clone(),
 947                        ContextServerSettings::Extension {
 948                            enabled: true,
 949                            settings: json!({
 950                                "somevalue": false
 951                            }),
 952                        },
 953                    ),
 954                    (
 955                        server_2_id.0.clone(),
 956                        ContextServerSettings::Custom {
 957                            enabled: true,
 958                            command: ContextServerCommand {
 959                                path: "somebinary".to_string(),
 960                                args: vec!["anotherArg".to_string()],
 961                                env: None,
 962                            },
 963                        },
 964                    ),
 965                ],
 966                cx,
 967            );
 968
 969            cx.run_until_parked();
 970        }
 971
 972        // Ensure that mcp-2 is removed once it is removed from the settings
 973        {
 974            let _server_events = assert_server_events(
 975                &store,
 976                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 977                cx,
 978            );
 979            set_context_server_configuration(
 980                vec![(
 981                    server_1_id.0.clone(),
 982                    ContextServerSettings::Extension {
 983                        enabled: true,
 984                        settings: json!({
 985                            "somevalue": false
 986                        }),
 987                    },
 988                )],
 989                cx,
 990            );
 991
 992            cx.run_until_parked();
 993
 994            cx.update(|cx| {
 995                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 996            });
 997        }
 998
 999        // Ensure that nothing happens if the settings do not change
1000        {
1001            let _server_events = assert_server_events(&store, vec![], cx);
1002            set_context_server_configuration(
1003                vec![(
1004                    server_1_id.0.clone(),
1005                    ContextServerSettings::Extension {
1006                        enabled: true,
1007                        settings: json!({
1008                            "somevalue": false
1009                        }),
1010                    },
1011                )],
1012                cx,
1013            );
1014
1015            cx.run_until_parked();
1016
1017            cx.update(|cx| {
1018                assert_eq!(
1019                    store.read(cx).status_for_server(&server_1_id),
1020                    Some(ContextServerStatus::Running)
1021                );
1022                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1023            });
1024        }
1025    }
1026
1027    #[gpui::test]
1028    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1029        const SERVER_1_ID: &'static str = "mcp-1";
1030
1031        let server_1_id = ContextServerId(SERVER_1_ID.into());
1032
1033        let (_fs, project) = setup_context_server_test(
1034            cx,
1035            json!({"code.rs": ""}),
1036            vec![(
1037                SERVER_1_ID.into(),
1038                ContextServerSettings::Custom {
1039                    enabled: true,
1040                    command: ContextServerCommand {
1041                        path: "somebinary".to_string(),
1042                        args: vec!["arg".to_string()],
1043                        env: None,
1044                    },
1045                },
1046            )],
1047        )
1048        .await;
1049
1050        let executor = cx.executor();
1051        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1052        let store = cx.new(|cx| {
1053            ContextServerStore::test_maintain_server_loop(
1054                Box::new(move |id, _| {
1055                    Arc::new(ContextServer::new(
1056                        id.clone(),
1057                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1058                    ))
1059                }),
1060                registry.clone(),
1061                project.read(cx).worktree_store(),
1062                cx,
1063            )
1064        });
1065
1066        // Ensure that mcp-1 starts up
1067        {
1068            let _server_events = assert_server_events(
1069                &store,
1070                vec![
1071                    (server_1_id.clone(), ContextServerStatus::Starting),
1072                    (server_1_id.clone(), ContextServerStatus::Running),
1073                ],
1074                cx,
1075            );
1076            cx.run_until_parked();
1077        }
1078
1079        // Ensure that mcp-1 is stopped once it is disabled.
1080        {
1081            let _server_events = assert_server_events(
1082                &store,
1083                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1084                cx,
1085            );
1086            set_context_server_configuration(
1087                vec![(
1088                    server_1_id.0.clone(),
1089                    ContextServerSettings::Custom {
1090                        enabled: false,
1091                        command: ContextServerCommand {
1092                            path: "somebinary".to_string(),
1093                            args: vec!["arg".to_string()],
1094                            env: None,
1095                        },
1096                    },
1097                )],
1098                cx,
1099            );
1100
1101            cx.run_until_parked();
1102        }
1103
1104        // Ensure that mcp-1 is started once it is enabled again.
1105        {
1106            let _server_events = assert_server_events(
1107                &store,
1108                vec![
1109                    (server_1_id.clone(), ContextServerStatus::Starting),
1110                    (server_1_id.clone(), ContextServerStatus::Running),
1111                ],
1112                cx,
1113            );
1114            set_context_server_configuration(
1115                vec![(
1116                    server_1_id.0.clone(),
1117                    ContextServerSettings::Custom {
1118                        enabled: true,
1119                        command: ContextServerCommand {
1120                            path: "somebinary".to_string(),
1121                            args: vec!["arg".to_string()],
1122                            env: None,
1123                        },
1124                    },
1125                )],
1126                cx,
1127            );
1128
1129            cx.run_until_parked();
1130        }
1131    }
1132
1133    fn set_context_server_configuration(
1134        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1135        cx: &mut TestAppContext,
1136    ) {
1137        cx.update(|cx| {
1138            SettingsStore::update_global(cx, |store, cx| {
1139                let mut settings = ProjectSettings::default();
1140                for (id, config) in context_servers {
1141                    settings.context_servers.insert(id, config);
1142                }
1143                store
1144                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1145                    .unwrap();
1146            })
1147        });
1148    }
1149
1150    struct ServerEvents {
1151        received_event_count: Rc<RefCell<usize>>,
1152        expected_event_count: usize,
1153        _subscription: Subscription,
1154    }
1155
1156    impl Drop for ServerEvents {
1157        fn drop(&mut self) {
1158            let actual_event_count = *self.received_event_count.borrow();
1159            assert_eq!(
1160                actual_event_count, self.expected_event_count,
1161                "
1162                Expected to receive {} context server store events, but received {} events",
1163                self.expected_event_count, actual_event_count
1164            );
1165        }
1166    }
1167
1168    fn dummy_server_settings() -> ContextServerSettings {
1169        ContextServerSettings::Custom {
1170            enabled: true,
1171            command: ContextServerCommand {
1172                path: "somebinary".to_string(),
1173                args: vec!["arg".to_string()],
1174                env: None,
1175            },
1176        }
1177    }
1178
1179    fn assert_server_events(
1180        store: &Entity<ContextServerStore>,
1181        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1182        cx: &mut TestAppContext,
1183    ) -> ServerEvents {
1184        cx.update(|cx| {
1185            let mut ix = 0;
1186            let received_event_count = Rc::new(RefCell::new(0));
1187            let expected_event_count = expected_events.len();
1188            let subscription = cx.subscribe(store, {
1189                let received_event_count = received_event_count.clone();
1190                move |_, event, _| match event {
1191                    Event::ServerStatusChanged {
1192                        server_id: actual_server_id,
1193                        status: actual_status,
1194                    } => {
1195                        let (expected_server_id, expected_status) = &expected_events[ix];
1196
1197                        assert_eq!(
1198                            actual_server_id, expected_server_id,
1199                            "Expected different server id at index {}",
1200                            ix
1201                        );
1202                        assert_eq!(
1203                            actual_status, expected_status,
1204                            "Expected different status at index {}",
1205                            ix
1206                        );
1207                        ix += 1;
1208                        *received_event_count.borrow_mut() += 1;
1209                    }
1210                }
1211            });
1212            ServerEvents {
1213                expected_event_count,
1214                received_event_count,
1215                _subscription: subscription,
1216            }
1217        })
1218    }
1219
1220    async fn setup_context_server_test(
1221        cx: &mut TestAppContext,
1222        files: serde_json::Value,
1223        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1224    ) -> (Arc<FakeFs>, Entity<Project>) {
1225        cx.update(|cx| {
1226            let settings_store = SettingsStore::test(cx);
1227            cx.set_global(settings_store);
1228            Project::init_settings(cx);
1229            let mut settings = ProjectSettings::get_global(cx).clone();
1230            for (id, config) in context_server_configurations {
1231                settings.context_servers.insert(id, config);
1232            }
1233            ProjectSettings::override_global(settings, cx);
1234        });
1235
1236        let fs = FakeFs::new(cx.executor());
1237        fs.insert_tree(path!("/test"), files).await;
1238        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1239
1240        (fs, project)
1241    }
1242
1243    struct FakeContextServerDescriptor {
1244        path: String,
1245    }
1246
1247    impl FakeContextServerDescriptor {
1248        fn new(path: impl Into<String>) -> Self {
1249            Self { path: path.into() }
1250        }
1251    }
1252
1253    impl ContextServerDescriptor for FakeContextServerDescriptor {
1254        fn command(
1255            &self,
1256            _worktree_store: Entity<WorktreeStore>,
1257            _cx: &AsyncApp,
1258        ) -> Task<Result<ContextServerCommand>> {
1259            Task::ready(Ok(ContextServerCommand {
1260                path: self.path.clone(),
1261                args: vec!["arg1".to_string(), "arg2".to_string()],
1262                env: None,
1263            }))
1264        }
1265
1266        fn configuration(
1267            &self,
1268            _worktree_store: Entity<WorktreeStore>,
1269            _cx: &AsyncApp,
1270        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1271            Task::ready(Ok(None))
1272        }
1273    }
1274}