context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    Project,
  17    project_settings::{ContextServerSettings, ProjectSettings},
  18    worktree_store::WorktreeStore,
  19};
  20
  21pub fn init(cx: &mut App) {
  22    extension::init(cx);
  23}
  24
// Declares the actions in the `context_server` namespace (currently only `Restart`).
// Keep the `///` line inside the macro verbatim.
// NOTE(review): `actions!` presumably surfaces it as the action's user-facing documentation.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  32
  33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  34pub enum ContextServerStatus {
  35    Starting,
  36    Running,
  37    Stopped,
  38    Error(Arc<str>),
  39}
  40
  41impl ContextServerStatus {
  42    fn from_state(state: &ContextServerState) -> Self {
  43        match state {
  44            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  45            ContextServerState::Running { .. } => ContextServerStatus::Running,
  46            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  47            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  48        }
  49    }
  50}
  51
  52enum ContextServerState {
  53    Starting {
  54        server: Arc<ContextServer>,
  55        configuration: Arc<ContextServerConfiguration>,
  56        _task: Task<()>,
  57    },
  58    Running {
  59        server: Arc<ContextServer>,
  60        configuration: Arc<ContextServerConfiguration>,
  61    },
  62    Stopped {
  63        server: Arc<ContextServer>,
  64        configuration: Arc<ContextServerConfiguration>,
  65    },
  66    Error {
  67        server: Arc<ContextServer>,
  68        configuration: Arc<ContextServerConfiguration>,
  69        error: Arc<str>,
  70    },
  71}
  72
  73impl ContextServerState {
  74    pub fn server(&self) -> Arc<ContextServer> {
  75        match self {
  76            ContextServerState::Starting { server, .. } => server.clone(),
  77            ContextServerState::Running { server, .. } => server.clone(),
  78            ContextServerState::Stopped { server, .. } => server.clone(),
  79            ContextServerState::Error { server, .. } => server.clone(),
  80        }
  81    }
  82
  83    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  84        match self {
  85            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  86            ContextServerState::Running { configuration, .. } => configuration.clone(),
  87            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  88            ContextServerState::Error { configuration, .. } => configuration.clone(),
  89        }
  90    }
  91}
  92
  93#[derive(Debug, PartialEq, Eq)]
  94pub enum ContextServerConfiguration {
  95    Custom {
  96        command: ContextServerCommand,
  97    },
  98    Extension {
  99        command: ContextServerCommand,
 100        settings: serde_json::Value,
 101    },
 102}
 103
 104impl ContextServerConfiguration {
 105    pub fn command(&self) -> &ContextServerCommand {
 106        match self {
 107            ContextServerConfiguration::Custom { command } => command,
 108            ContextServerConfiguration::Extension { command, .. } => command,
 109        }
 110    }
 111
 112    pub async fn from_settings(
 113        settings: ContextServerSettings,
 114        id: ContextServerId,
 115        registry: Entity<ContextServerDescriptorRegistry>,
 116        worktree_store: Entity<WorktreeStore>,
 117        cx: &AsyncApp,
 118    ) -> Option<Self> {
 119        match settings {
 120            ContextServerSettings::Custom {
 121                enabled: _,
 122                command,
 123            } => Some(ContextServerConfiguration::Custom { command }),
 124            ContextServerSettings::Extension {
 125                enabled: _,
 126                settings,
 127            } => {
 128                let descriptor = cx
 129                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 130                    .ok()
 131                    .flatten()?;
 132
 133                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 134
 135                Some(ContextServerConfiguration::Extension { command, settings })
 136            }
 137        }
 138    }
 139}
 140
 141pub type ContextServerFactory =
 142    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 143
 144pub struct ContextServerStore {
 145    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
 146    servers: HashMap<ContextServerId, ContextServerState>,
 147    worktree_store: Entity<WorktreeStore>,
 148    project: WeakEntity<Project>,
 149    registry: Entity<ContextServerDescriptorRegistry>,
 150    update_servers_task: Option<Task<Result<()>>>,
 151    context_server_factory: Option<ContextServerFactory>,
 152    needs_server_update: bool,
 153    _subscriptions: Vec<Subscription>,
 154}
 155
 156pub enum Event {
 157    ServerStatusChanged {
 158        server_id: ContextServerId,
 159        status: ContextServerStatus,
 160    },
 161}
 162
 163impl EventEmitter<Event> for ContextServerStore {}
 164
 165impl ContextServerStore {
 166    pub fn new(
 167        worktree_store: Entity<WorktreeStore>,
 168        weak_project: WeakEntity<Project>,
 169        cx: &mut Context<Self>,
 170    ) -> Self {
 171        Self::new_internal(
 172            true,
 173            None,
 174            ContextServerDescriptorRegistry::default_global(cx),
 175            worktree_store,
 176            weak_project,
 177            cx,
 178        )
 179    }
 180
 181    /// Returns all configured context server ids, regardless of enabled state.
 182    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 183        self.context_server_settings
 184            .keys()
 185            .cloned()
 186            .map(ContextServerId)
 187            .collect()
 188    }
 189
 190    #[cfg(any(test, feature = "test-support"))]
 191    pub fn test(
 192        registry: Entity<ContextServerDescriptorRegistry>,
 193        worktree_store: Entity<WorktreeStore>,
 194        weak_project: WeakEntity<Project>,
 195        cx: &mut Context<Self>,
 196    ) -> Self {
 197        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
 198    }
 199
 200    #[cfg(any(test, feature = "test-support"))]
 201    pub fn test_maintain_server_loop(
 202        context_server_factory: ContextServerFactory,
 203        registry: Entity<ContextServerDescriptorRegistry>,
 204        worktree_store: Entity<WorktreeStore>,
 205        weak_project: WeakEntity<Project>,
 206        cx: &mut Context<Self>,
 207    ) -> Self {
 208        Self::new_internal(
 209            true,
 210            Some(context_server_factory),
 211            registry,
 212            worktree_store,
 213            weak_project,
 214            cx,
 215        )
 216    }
 217
 218    fn new_internal(
 219        maintain_server_loop: bool,
 220        context_server_factory: Option<ContextServerFactory>,
 221        registry: Entity<ContextServerDescriptorRegistry>,
 222        worktree_store: Entity<WorktreeStore>,
 223        weak_project: WeakEntity<Project>,
 224        cx: &mut Context<Self>,
 225    ) -> Self {
 226        let subscriptions = if maintain_server_loop {
 227            vec![
 228                cx.observe(&registry, |this, _registry, cx| {
 229                    this.available_context_servers_changed(cx);
 230                }),
 231                cx.observe_global::<SettingsStore>(|this, cx| {
 232                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 233                    if &this.context_server_settings == settings {
 234                        return;
 235                    }
 236                    this.context_server_settings = settings.clone();
 237                    this.available_context_servers_changed(cx);
 238                }),
 239            ]
 240        } else {
 241            Vec::new()
 242        };
 243
 244        let mut this = Self {
 245            _subscriptions: subscriptions,
 246            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 247                .clone(),
 248            worktree_store,
 249            project: weak_project,
 250            registry,
 251            needs_server_update: false,
 252            servers: HashMap::default(),
 253            update_servers_task: None,
 254            context_server_factory,
 255        };
 256        if maintain_server_loop {
 257            this.available_context_servers_changed(cx);
 258        }
 259        this
 260    }
 261
 262    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 263        self.servers.get(id).map(|state| state.server())
 264    }
 265
 266    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 267        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 268            Some(server.clone())
 269        } else {
 270            None
 271        }
 272    }
 273
 274    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 275        self.servers.get(id).map(ContextServerStatus::from_state)
 276    }
 277
 278    pub fn configuration_for_server(
 279        &self,
 280        id: &ContextServerId,
 281    ) -> Option<Arc<ContextServerConfiguration>> {
 282        self.servers.get(id).map(|state| state.configuration())
 283    }
 284
 285    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 286        self.servers.keys().cloned().collect()
 287    }
 288
 289    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 290        self.servers
 291            .values()
 292            .filter_map(|state| {
 293                if let ContextServerState::Running { server, .. } = state {
 294                    Some(server.clone())
 295                } else {
 296                    None
 297                }
 298            })
 299            .collect()
 300    }
 301
 302    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 303        cx.spawn(async move |this, cx| {
 304            let this = this.upgrade().context("Context server store dropped")?;
 305            let settings = this
 306                .update(cx, |this, _| {
 307                    this.context_server_settings.get(&server.id().0).cloned()
 308                })
 309                .ok()
 310                .flatten()
 311                .context("Failed to get context server settings")?;
 312
 313            if !settings.enabled() {
 314                return Ok(());
 315            }
 316
 317            let (registry, worktree_store) = this.update(cx, |this, _| {
 318                (this.registry.clone(), this.worktree_store.clone())
 319            })?;
 320            let configuration = ContextServerConfiguration::from_settings(
 321                settings,
 322                server.id(),
 323                registry,
 324                worktree_store,
 325                cx,
 326            )
 327            .await
 328            .context("Failed to create context server configuration")?;
 329
 330            this.update(cx, |this, cx| {
 331                this.run_server(server, Arc::new(configuration), cx)
 332            })
 333        })
 334        .detach_and_log_err(cx);
 335    }
 336
 337    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 338        if matches!(
 339            self.servers.get(id),
 340            Some(ContextServerState::Stopped { .. })
 341        ) {
 342            return Ok(());
 343        }
 344
 345        let state = self
 346            .servers
 347            .remove(id)
 348            .context("Context server not found")?;
 349
 350        let server = state.server();
 351        let configuration = state.configuration();
 352        let mut result = Ok(());
 353        if let ContextServerState::Running { server, .. } = &state {
 354            result = server.stop();
 355        }
 356        drop(state);
 357
 358        self.update_server_state(
 359            id.clone(),
 360            ContextServerState::Stopped {
 361                configuration,
 362                server,
 363            },
 364            cx,
 365        );
 366
 367        result
 368    }
 369
 370    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 371        if let Some(state) = self.servers.get(id) {
 372            let configuration = state.configuration();
 373
 374            self.stop_server(&state.server().id(), cx)?;
 375            let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
 376            self.run_server(new_server, configuration, cx);
 377        }
 378        Ok(())
 379    }
 380
 381    fn run_server(
 382        &mut self,
 383        server: Arc<ContextServer>,
 384        configuration: Arc<ContextServerConfiguration>,
 385        cx: &mut Context<Self>,
 386    ) {
 387        let id = server.id();
 388        if matches!(
 389            self.servers.get(&id),
 390            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 391        ) {
 392            self.stop_server(&id, cx).log_err();
 393        }
 394
 395        let task = cx.spawn({
 396            let id = server.id();
 397            let server = server.clone();
 398            let configuration = configuration.clone();
 399            async move |this, cx| {
 400                match server.clone().start(cx).await {
 401                    Ok(_) => {
 402                        log::info!("Started {} context server", id);
 403                        debug_assert!(server.client().is_some());
 404
 405                        this.update(cx, |this, cx| {
 406                            this.update_server_state(
 407                                id.clone(),
 408                                ContextServerState::Running {
 409                                    server,
 410                                    configuration,
 411                                },
 412                                cx,
 413                            )
 414                        })
 415                        .log_err()
 416                    }
 417                    Err(err) => {
 418                        log::error!("{} context server failed to start: {}", id, err);
 419                        this.update(cx, |this, cx| {
 420                            this.update_server_state(
 421                                id.clone(),
 422                                ContextServerState::Error {
 423                                    configuration,
 424                                    server,
 425                                    error: err.to_string().into(),
 426                                },
 427                                cx,
 428                            )
 429                        })
 430                        .log_err()
 431                    }
 432                };
 433            }
 434        });
 435
 436        self.update_server_state(
 437            id.clone(),
 438            ContextServerState::Starting {
 439                configuration,
 440                _task: task,
 441                server,
 442            },
 443            cx,
 444        );
 445    }
 446
 447    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 448        let state = self
 449            .servers
 450            .remove(id)
 451            .context("Context server not found")?;
 452        drop(state);
 453        cx.emit(Event::ServerStatusChanged {
 454            server_id: id.clone(),
 455            status: ContextServerStatus::Stopped,
 456        });
 457        Ok(())
 458    }
 459
 460    fn create_context_server(
 461        &self,
 462        id: ContextServerId,
 463        configuration: Arc<ContextServerConfiguration>,
 464        cx: &mut Context<Self>,
 465    ) -> Arc<ContextServer> {
 466        let root_path = self
 467            .project
 468            .read_with(cx, |project, cx| project.active_project_directory(cx))
 469            .ok()
 470            .flatten()
 471            .or_else(|| {
 472                self.worktree_store.read_with(cx, |store, cx| {
 473                    store.visible_worktrees(cx).fold(None, |acc, item| {
 474                        if acc.is_none() {
 475                            item.read(cx).root_dir()
 476                        } else {
 477                            acc
 478                        }
 479                    })
 480                })
 481            });
 482
 483        if let Some(factory) = self.context_server_factory.as_ref() {
 484            factory(id, configuration)
 485        } else {
 486            Arc::new(ContextServer::stdio(
 487                id,
 488                configuration.command().clone(),
 489                root_path,
 490            ))
 491        }
 492    }
 493
 494    fn resolve_context_server_settings<'a>(
 495        worktree_store: &'a Entity<WorktreeStore>,
 496        cx: &'a App,
 497    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 498        let location = worktree_store
 499            .read(cx)
 500            .visible_worktrees(cx)
 501            .next()
 502            .map(|worktree| settings::SettingsLocation {
 503                worktree_id: worktree.read(cx).id(),
 504                path: Path::new(""),
 505            });
 506        &ProjectSettings::get(location, cx).context_servers
 507    }
 508
 509    fn update_server_state(
 510        &mut self,
 511        id: ContextServerId,
 512        state: ContextServerState,
 513        cx: &mut Context<Self>,
 514    ) {
 515        let status = ContextServerStatus::from_state(&state);
 516        self.servers.insert(id.clone(), state);
 517        cx.emit(Event::ServerStatusChanged {
 518            server_id: id,
 519            status,
 520        });
 521    }
 522
 523    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 524        if self.update_servers_task.is_some() {
 525            self.needs_server_update = true;
 526        } else {
 527            self.needs_server_update = false;
 528            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 529                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 530                    log::error!("Error maintaining context servers: {}", err);
 531                }
 532
 533                this.update(cx, |this, cx| {
 534                    this.update_servers_task.take();
 535                    if this.needs_server_update {
 536                        this.available_context_servers_changed(cx);
 537                    }
 538                })?;
 539
 540                Ok(())
 541            }));
 542        }
 543    }
 544
 545    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 546        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 547            (
 548                this.context_server_settings.clone(),
 549                this.registry.clone(),
 550                this.worktree_store.clone(),
 551            )
 552        })?;
 553
 554        for (id, _) in
 555            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 556        {
 557            configured_servers
 558                .entry(id)
 559                .or_insert(ContextServerSettings::default_extension());
 560        }
 561
 562        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 563            configured_servers
 564                .into_iter()
 565                .partition(|(_, settings)| settings.enabled());
 566
 567        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 568            let id = ContextServerId(id);
 569            ContextServerConfiguration::from_settings(
 570                settings,
 571                id.clone(),
 572                registry.clone(),
 573                worktree_store.clone(),
 574                cx,
 575            )
 576            .map(|config| (id, config))
 577        }))
 578        .await
 579        .into_iter()
 580        .filter_map(|(id, config)| config.map(|config| (id, config)))
 581        .collect::<HashMap<_, _>>();
 582
 583        let mut servers_to_start = Vec::new();
 584        let mut servers_to_remove = HashSet::default();
 585        let mut servers_to_stop = HashSet::default();
 586
 587        this.update(cx, |this, cx| {
 588            for server_id in this.servers.keys() {
 589                // All servers that are not in desired_servers should be removed from the store.
 590                // This can happen if the user removed a server from the context server settings.
 591                if !configured_servers.contains_key(server_id) {
 592                    if disabled_servers.contains_key(&server_id.0) {
 593                        servers_to_stop.insert(server_id.clone());
 594                    } else {
 595                        servers_to_remove.insert(server_id.clone());
 596                    }
 597                }
 598            }
 599
 600            for (id, config) in configured_servers {
 601                let state = this.servers.get(&id);
 602                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 603                let existing_config = state.as_ref().map(|state| state.configuration());
 604                if existing_config.as_deref() != Some(&config) || is_stopped {
 605                    let config = Arc::new(config);
 606                    let server = this.create_context_server(id.clone(), config.clone(), cx);
 607                    servers_to_start.push((server, config));
 608                    if this.servers.contains_key(&id) {
 609                        servers_to_stop.insert(id);
 610                    }
 611                }
 612            }
 613        })?;
 614
 615        this.update(cx, |this, cx| {
 616            for id in servers_to_stop {
 617                this.stop_server(&id, cx)?;
 618            }
 619            for id in servers_to_remove {
 620                this.remove_server(&id, cx)?;
 621            }
 622            for (server, config) in servers_to_start {
 623                this.run_server(server, config, cx);
 624            }
 625            anyhow::Ok(())
 626        })?
 627    }
 628}
 629
 630#[cfg(test)]
 631mod tests {
 632    use super::*;
 633    use crate::{
 634        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 635        project_settings::ProjectSettings,
 636    };
 637    use context_server::test::create_fake_transport;
 638    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 639    use serde_json::json;
 640    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 641    use util::path;
 642
 643    #[gpui::test]
 644    async fn test_context_server_status(cx: &mut TestAppContext) {
 645        const SERVER_1_ID: &str = "mcp-1";
 646        const SERVER_2_ID: &str = "mcp-2";
 647
 648        let (_fs, project) = setup_context_server_test(
 649            cx,
 650            json!({"code.rs": ""}),
 651            vec![
 652                (SERVER_1_ID.into(), dummy_server_settings()),
 653                (SERVER_2_ID.into(), dummy_server_settings()),
 654            ],
 655        )
 656        .await;
 657
 658        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 659        let store = cx.new(|cx| {
 660            ContextServerStore::test(
 661                registry.clone(),
 662                project.read(cx).worktree_store(),
 663                project.downgrade(),
 664                cx,
 665            )
 666        });
 667
 668        let server_1_id = ContextServerId(SERVER_1_ID.into());
 669        let server_2_id = ContextServerId(SERVER_2_ID.into());
 670
 671        let server_1 = Arc::new(ContextServer::new(
 672            server_1_id.clone(),
 673            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 674        ));
 675        let server_2 = Arc::new(ContextServer::new(
 676            server_2_id.clone(),
 677            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 678        ));
 679
 680        store.update(cx, |store, cx| store.start_server(server_1, cx));
 681
 682        cx.run_until_parked();
 683
 684        cx.update(|cx| {
 685            assert_eq!(
 686                store.read(cx).status_for_server(&server_1_id),
 687                Some(ContextServerStatus::Running)
 688            );
 689            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 690        });
 691
 692        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 693
 694        cx.run_until_parked();
 695
 696        cx.update(|cx| {
 697            assert_eq!(
 698                store.read(cx).status_for_server(&server_1_id),
 699                Some(ContextServerStatus::Running)
 700            );
 701            assert_eq!(
 702                store.read(cx).status_for_server(&server_2_id),
 703                Some(ContextServerStatus::Running)
 704            );
 705        });
 706
 707        store
 708            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 709            .unwrap();
 710
 711        cx.update(|cx| {
 712            assert_eq!(
 713                store.read(cx).status_for_server(&server_1_id),
 714                Some(ContextServerStatus::Running)
 715            );
 716            assert_eq!(
 717                store.read(cx).status_for_server(&server_2_id),
 718                Some(ContextServerStatus::Stopped)
 719            );
 720        });
 721    }
 722
 723    #[gpui::test]
 724    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 725        const SERVER_1_ID: &str = "mcp-1";
 726        const SERVER_2_ID: &str = "mcp-2";
 727
 728        let (_fs, project) = setup_context_server_test(
 729            cx,
 730            json!({"code.rs": ""}),
 731            vec![
 732                (SERVER_1_ID.into(), dummy_server_settings()),
 733                (SERVER_2_ID.into(), dummy_server_settings()),
 734            ],
 735        )
 736        .await;
 737
 738        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 739        let store = cx.new(|cx| {
 740            ContextServerStore::test(
 741                registry.clone(),
 742                project.read(cx).worktree_store(),
 743                project.downgrade(),
 744                cx,
 745            )
 746        });
 747
 748        let server_1_id = ContextServerId(SERVER_1_ID.into());
 749        let server_2_id = ContextServerId(SERVER_2_ID.into());
 750
 751        let server_1 = Arc::new(ContextServer::new(
 752            server_1_id.clone(),
 753            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 754        ));
 755        let server_2 = Arc::new(ContextServer::new(
 756            server_2_id.clone(),
 757            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 758        ));
 759
 760        let _server_events = assert_server_events(
 761            &store,
 762            vec![
 763                (server_1_id.clone(), ContextServerStatus::Starting),
 764                (server_1_id.clone(), ContextServerStatus::Running),
 765                (server_2_id.clone(), ContextServerStatus::Starting),
 766                (server_2_id.clone(), ContextServerStatus::Running),
 767                (server_2_id.clone(), ContextServerStatus::Stopped),
 768            ],
 769            cx,
 770        );
 771
 772        store.update(cx, |store, cx| store.start_server(server_1, cx));
 773
 774        cx.run_until_parked();
 775
 776        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 777
 778        cx.run_until_parked();
 779
 780        store
 781            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 782            .unwrap();
 783    }
 784
 785    #[gpui::test(iterations = 25)]
 786    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 787        const SERVER_1_ID: &str = "mcp-1";
 788
 789        let (_fs, project) = setup_context_server_test(
 790            cx,
 791            json!({"code.rs": ""}),
 792            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 793        )
 794        .await;
 795
 796        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 797        let store = cx.new(|cx| {
 798            ContextServerStore::test(
 799                registry.clone(),
 800                project.read(cx).worktree_store(),
 801                project.downgrade(),
 802                cx,
 803            )
 804        });
 805
 806        let server_id = ContextServerId(SERVER_1_ID.into());
 807
 808        let server_with_same_id_1 = Arc::new(ContextServer::new(
 809            server_id.clone(),
 810            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 811        ));
 812        let server_with_same_id_2 = Arc::new(ContextServer::new(
 813            server_id.clone(),
 814            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 815        ));
 816
 817        // If we start another server with the same id, we should report that we stopped the previous one
 818        let _server_events = assert_server_events(
 819            &store,
 820            vec![
 821                (server_id.clone(), ContextServerStatus::Starting),
 822                (server_id.clone(), ContextServerStatus::Stopped),
 823                (server_id.clone(), ContextServerStatus::Starting),
 824                (server_id.clone(), ContextServerStatus::Running),
 825            ],
 826            cx,
 827        );
 828
 829        store.update(cx, |store, cx| {
 830            store.start_server(server_with_same_id_1.clone(), cx)
 831        });
 832        store.update(cx, |store, cx| {
 833            store.start_server(server_with_same_id_2.clone(), cx)
 834        });
 835
 836        cx.run_until_parked();
 837
 838        cx.update(|cx| {
 839            assert_eq!(
 840                store.read(cx).status_for_server(&server_id),
 841                Some(ContextServerStatus::Running)
 842            );
 843        });
 844    }
 845
 846    #[gpui::test]
 847    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 848        const SERVER_1_ID: &str = "mcp-1";
 849        const SERVER_2_ID: &str = "mcp-2";
 850
 851        let server_1_id = ContextServerId(SERVER_1_ID.into());
 852        let server_2_id = ContextServerId(SERVER_2_ID.into());
 853
 854        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 855
 856        let (_fs, project) = setup_context_server_test(
 857            cx,
 858            json!({"code.rs": ""}),
 859            vec![(
 860                SERVER_1_ID.into(),
 861                ContextServerSettings::Extension {
 862                    enabled: true,
 863                    settings: json!({
 864                        "somevalue": true
 865                    }),
 866                },
 867            )],
 868        )
 869        .await;
 870
 871        let executor = cx.executor();
 872        let registry = cx.new(|cx| {
 873            let mut registry = ContextServerDescriptorRegistry::new();
 874            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 875            registry
 876        });
 877        let store = cx.new(|cx| {
 878            ContextServerStore::test_maintain_server_loop(
 879                Box::new(move |id, _| {
 880                    Arc::new(ContextServer::new(
 881                        id.clone(),
 882                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 883                    ))
 884                }),
 885                registry.clone(),
 886                project.read(cx).worktree_store(),
 887                project.downgrade(),
 888                cx,
 889            )
 890        });
 891
 892        // Ensure that mcp-1 starts up
 893        {
 894            let _server_events = assert_server_events(
 895                &store,
 896                vec![
 897                    (server_1_id.clone(), ContextServerStatus::Starting),
 898                    (server_1_id.clone(), ContextServerStatus::Running),
 899                ],
 900                cx,
 901            );
 902            cx.run_until_parked();
 903        }
 904
 905        // Ensure that mcp-1 is restarted when the configuration was changed
 906        {
 907            let _server_events = assert_server_events(
 908                &store,
 909                vec![
 910                    (server_1_id.clone(), ContextServerStatus::Stopped),
 911                    (server_1_id.clone(), ContextServerStatus::Starting),
 912                    (server_1_id.clone(), ContextServerStatus::Running),
 913                ],
 914                cx,
 915            );
 916            set_context_server_configuration(
 917                vec![(
 918                    server_1_id.0.clone(),
 919                    ContextServerSettings::Extension {
 920                        enabled: true,
 921                        settings: json!({
 922                            "somevalue": false
 923                        }),
 924                    },
 925                )],
 926                cx,
 927            );
 928
 929            cx.run_until_parked();
 930        }
 931
 932        // Ensure that mcp-1 is not restarted when the configuration was not changed
 933        {
 934            let _server_events = assert_server_events(&store, vec![], cx);
 935            set_context_server_configuration(
 936                vec![(
 937                    server_1_id.0.clone(),
 938                    ContextServerSettings::Extension {
 939                        enabled: true,
 940                        settings: json!({
 941                            "somevalue": false
 942                        }),
 943                    },
 944                )],
 945                cx,
 946            );
 947
 948            cx.run_until_parked();
 949        }
 950
 951        // Ensure that mcp-2 is started once it is added to the settings
 952        {
 953            let _server_events = assert_server_events(
 954                &store,
 955                vec![
 956                    (server_2_id.clone(), ContextServerStatus::Starting),
 957                    (server_2_id.clone(), ContextServerStatus::Running),
 958                ],
 959                cx,
 960            );
 961            set_context_server_configuration(
 962                vec![
 963                    (
 964                        server_1_id.0.clone(),
 965                        ContextServerSettings::Extension {
 966                            enabled: true,
 967                            settings: json!({
 968                                "somevalue": false
 969                            }),
 970                        },
 971                    ),
 972                    (
 973                        server_2_id.0.clone(),
 974                        ContextServerSettings::Custom {
 975                            enabled: true,
 976                            command: ContextServerCommand {
 977                                path: "somebinary".into(),
 978                                args: vec!["arg".to_string()],
 979                                env: None,
 980                            },
 981                        },
 982                    ),
 983                ],
 984                cx,
 985            );
 986
 987            cx.run_until_parked();
 988        }
 989
 990        // Ensure that mcp-2 is restarted once the args have changed
 991        {
 992            let _server_events = assert_server_events(
 993                &store,
 994                vec![
 995                    (server_2_id.clone(), ContextServerStatus::Stopped),
 996                    (server_2_id.clone(), ContextServerStatus::Starting),
 997                    (server_2_id.clone(), ContextServerStatus::Running),
 998                ],
 999                cx,
1000            );
1001            set_context_server_configuration(
1002                vec![
1003                    (
1004                        server_1_id.0.clone(),
1005                        ContextServerSettings::Extension {
1006                            enabled: true,
1007                            settings: json!({
1008                                "somevalue": false
1009                            }),
1010                        },
1011                    ),
1012                    (
1013                        server_2_id.0.clone(),
1014                        ContextServerSettings::Custom {
1015                            enabled: true,
1016                            command: ContextServerCommand {
1017                                path: "somebinary".into(),
1018                                args: vec!["anotherArg".to_string()],
1019                                env: None,
1020                            },
1021                        },
1022                    ),
1023                ],
1024                cx,
1025            );
1026
1027            cx.run_until_parked();
1028        }
1029
1030        // Ensure that mcp-2 is removed once it is removed from the settings
1031        {
1032            let _server_events = assert_server_events(
1033                &store,
1034                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1035                cx,
1036            );
1037            set_context_server_configuration(
1038                vec![(
1039                    server_1_id.0.clone(),
1040                    ContextServerSettings::Extension {
1041                        enabled: true,
1042                        settings: json!({
1043                            "somevalue": false
1044                        }),
1045                    },
1046                )],
1047                cx,
1048            );
1049
1050            cx.run_until_parked();
1051
1052            cx.update(|cx| {
1053                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1054            });
1055        }
1056
1057        // Ensure that nothing happens if the settings do not change
1058        {
1059            let _server_events = assert_server_events(&store, vec![], cx);
1060            set_context_server_configuration(
1061                vec![(
1062                    server_1_id.0.clone(),
1063                    ContextServerSettings::Extension {
1064                        enabled: true,
1065                        settings: json!({
1066                            "somevalue": false
1067                        }),
1068                    },
1069                )],
1070                cx,
1071            );
1072
1073            cx.run_until_parked();
1074
1075            cx.update(|cx| {
1076                assert_eq!(
1077                    store.read(cx).status_for_server(&server_1_id),
1078                    Some(ContextServerStatus::Running)
1079                );
1080                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1081            });
1082        }
1083    }
1084
1085    #[gpui::test]
1086    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1087        const SERVER_1_ID: &str = "mcp-1";
1088
1089        let server_1_id = ContextServerId(SERVER_1_ID.into());
1090
1091        let (_fs, project) = setup_context_server_test(
1092            cx,
1093            json!({"code.rs": ""}),
1094            vec![(
1095                SERVER_1_ID.into(),
1096                ContextServerSettings::Custom {
1097                    enabled: true,
1098                    command: ContextServerCommand {
1099                        path: "somebinary".into(),
1100                        args: vec!["arg".to_string()],
1101                        env: None,
1102                    },
1103                },
1104            )],
1105        )
1106        .await;
1107
1108        let executor = cx.executor();
1109        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1110        let store = cx.new(|cx| {
1111            ContextServerStore::test_maintain_server_loop(
1112                Box::new(move |id, _| {
1113                    Arc::new(ContextServer::new(
1114                        id.clone(),
1115                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1116                    ))
1117                }),
1118                registry.clone(),
1119                project.read(cx).worktree_store(),
1120                project.downgrade(),
1121                cx,
1122            )
1123        });
1124
1125        // Ensure that mcp-1 starts up
1126        {
1127            let _server_events = assert_server_events(
1128                &store,
1129                vec![
1130                    (server_1_id.clone(), ContextServerStatus::Starting),
1131                    (server_1_id.clone(), ContextServerStatus::Running),
1132                ],
1133                cx,
1134            );
1135            cx.run_until_parked();
1136        }
1137
1138        // Ensure that mcp-1 is stopped once it is disabled.
1139        {
1140            let _server_events = assert_server_events(
1141                &store,
1142                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1143                cx,
1144            );
1145            set_context_server_configuration(
1146                vec![(
1147                    server_1_id.0.clone(),
1148                    ContextServerSettings::Custom {
1149                        enabled: false,
1150                        command: ContextServerCommand {
1151                            path: "somebinary".into(),
1152                            args: vec!["arg".to_string()],
1153                            env: None,
1154                        },
1155                    },
1156                )],
1157                cx,
1158            );
1159
1160            cx.run_until_parked();
1161        }
1162
1163        // Ensure that mcp-1 is started once it is enabled again.
1164        {
1165            let _server_events = assert_server_events(
1166                &store,
1167                vec![
1168                    (server_1_id.clone(), ContextServerStatus::Starting),
1169                    (server_1_id.clone(), ContextServerStatus::Running),
1170                ],
1171                cx,
1172            );
1173            set_context_server_configuration(
1174                vec![(
1175                    server_1_id.0.clone(),
1176                    ContextServerSettings::Custom {
1177                        enabled: true,
1178                        command: ContextServerCommand {
1179                            path: "somebinary".into(),
1180                            args: vec!["arg".to_string()],
1181                            env: None,
1182                        },
1183                    },
1184                )],
1185                cx,
1186            );
1187
1188            cx.run_until_parked();
1189        }
1190    }
1191
1192    fn set_context_server_configuration(
1193        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1194        cx: &mut TestAppContext,
1195    ) {
1196        cx.update(|cx| {
1197            SettingsStore::update_global(cx, |store, cx| {
1198                let mut settings = ProjectSettings::default();
1199                for (id, config) in context_servers {
1200                    settings.context_servers.insert(id, config);
1201                }
1202                store
1203                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1204                    .unwrap();
1205            })
1206        });
1207    }
1208
1209    struct ServerEvents {
1210        received_event_count: Rc<RefCell<usize>>,
1211        expected_event_count: usize,
1212        _subscription: Subscription,
1213    }
1214
1215    impl Drop for ServerEvents {
1216        fn drop(&mut self) {
1217            let actual_event_count = *self.received_event_count.borrow();
1218            assert_eq!(
1219                actual_event_count, self.expected_event_count,
1220                "
1221                Expected to receive {} context server store events, but received {} events",
1222                self.expected_event_count, actual_event_count
1223            );
1224        }
1225    }
1226
1227    fn dummy_server_settings() -> ContextServerSettings {
1228        ContextServerSettings::Custom {
1229            enabled: true,
1230            command: ContextServerCommand {
1231                path: "somebinary".into(),
1232                args: vec!["arg".to_string()],
1233                env: None,
1234            },
1235        }
1236    }
1237
1238    fn assert_server_events(
1239        store: &Entity<ContextServerStore>,
1240        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1241        cx: &mut TestAppContext,
1242    ) -> ServerEvents {
1243        cx.update(|cx| {
1244            let mut ix = 0;
1245            let received_event_count = Rc::new(RefCell::new(0));
1246            let expected_event_count = expected_events.len();
1247            let subscription = cx.subscribe(store, {
1248                let received_event_count = received_event_count.clone();
1249                move |_, event, _| match event {
1250                    Event::ServerStatusChanged {
1251                        server_id: actual_server_id,
1252                        status: actual_status,
1253                    } => {
1254                        let (expected_server_id, expected_status) = &expected_events[ix];
1255
1256                        assert_eq!(
1257                            actual_server_id, expected_server_id,
1258                            "Expected different server id at index {}",
1259                            ix
1260                        );
1261                        assert_eq!(
1262                            actual_status, expected_status,
1263                            "Expected different status at index {}",
1264                            ix
1265                        );
1266                        ix += 1;
1267                        *received_event_count.borrow_mut() += 1;
1268                    }
1269                }
1270            });
1271            ServerEvents {
1272                expected_event_count,
1273                received_event_count,
1274                _subscription: subscription,
1275            }
1276        })
1277    }
1278
1279    async fn setup_context_server_test(
1280        cx: &mut TestAppContext,
1281        files: serde_json::Value,
1282        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1283    ) -> (Arc<FakeFs>, Entity<Project>) {
1284        cx.update(|cx| {
1285            let settings_store = SettingsStore::test(cx);
1286            cx.set_global(settings_store);
1287            Project::init_settings(cx);
1288            let mut settings = ProjectSettings::get_global(cx).clone();
1289            for (id, config) in context_server_configurations {
1290                settings.context_servers.insert(id, config);
1291            }
1292            ProjectSettings::override_global(settings, cx);
1293        });
1294
1295        let fs = FakeFs::new(cx.executor());
1296        fs.insert_tree(path!("/test"), files).await;
1297        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1298
1299        (fs, project)
1300    }
1301
1302    struct FakeContextServerDescriptor {
1303        path: PathBuf,
1304    }
1305
1306    impl FakeContextServerDescriptor {
1307        fn new(path: impl Into<PathBuf>) -> Self {
1308            Self { path: path.into() }
1309        }
1310    }
1311
1312    impl ContextServerDescriptor for FakeContextServerDescriptor {
1313        fn command(
1314            &self,
1315            _worktree_store: Entity<WorktreeStore>,
1316            _cx: &AsyncApp,
1317        ) -> Task<Result<ContextServerCommand>> {
1318            Task::ready(Ok(ContextServerCommand {
1319                path: self.path.clone(),
1320                args: vec!["arg1".to_string(), "arg2".to_string()],
1321                env: None,
1322            }))
1323        }
1324
1325        fn configuration(
1326            &self,
1327            _worktree_store: Entity<WorktreeStore>,
1328            _cx: &AsyncApp,
1329        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1330            Task::ready(Ok(None))
1331        }
1332    }
1333}