context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    Project,
  17    project_settings::{ContextServerSettings, ProjectSettings},
  18    worktree_store::WorktreeStore,
  19};
  20
  21pub fn init(cx: &mut App) {
  22    extension::init(cx);
  23}
  24
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  32
/// Publicly observable lifecycle status of a context server, as reported by
/// `ContextServerStore::status_for_server` and in `Event::ServerStatusChanged`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server has been handed to `run_server` and its start-up has not finished.
    Starting,
    /// The server's start-up completed successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// The server failed to start; carries the error message.
    Error(Arc<str>),
}
  40
  41impl ContextServerStatus {
  42    fn from_state(state: &ContextServerState) -> Self {
  43        match state {
  44            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  45            ContextServerState::Running { .. } => ContextServerStatus::Running,
  46            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  47            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  48        }
  49    }
  50}
  51
/// Internal per-server state tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was launched
/// with, so both can be read (and the server restarted) from any state.
enum ContextServerState {
    /// Start-up is in progress.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Owns the spawned start-up task, which therefore lives exactly as long
        // as this state. NOTE(review): presumably dropping it cancels start-up.
        _task: Task<()>,
    },
    /// Start-up completed successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped; the entry is kept so it can be started again.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Start-up failed; `error` holds the failure message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
  72
  73impl ContextServerState {
  74    pub fn server(&self) -> Arc<ContextServer> {
  75        match self {
  76            ContextServerState::Starting { server, .. } => server.clone(),
  77            ContextServerState::Running { server, .. } => server.clone(),
  78            ContextServerState::Stopped { server, .. } => server.clone(),
  79            ContextServerState::Error { server, .. } => server.clone(),
  80        }
  81    }
  82
  83    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  84        match self {
  85            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  86            ContextServerState::Running { configuration, .. } => configuration.clone(),
  87            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  88            ContextServerState::Error { configuration, .. } => configuration.clone(),
  89        }
  90    }
  91}
  92
/// How a context server is launched, as resolved by
/// `ContextServerConfiguration::from_settings`.
///
/// Values are compared with `==` by the store to decide whether an existing
/// server must be restarted after settings or the registry change.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server whose command is taken directly from the user's settings.
    Custom {
        command: ContextServerCommand,
    },
    /// A server provided through a registered descriptor: the command resolved
    /// from that descriptor, plus the server-specific settings JSON.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
 103
 104impl ContextServerConfiguration {
 105    pub fn command(&self) -> &ContextServerCommand {
 106        match self {
 107            ContextServerConfiguration::Custom { command } => command,
 108            ContextServerConfiguration::Extension { command, .. } => command,
 109        }
 110    }
 111
 112    pub async fn from_settings(
 113        settings: ContextServerSettings,
 114        id: ContextServerId,
 115        registry: Entity<ContextServerDescriptorRegistry>,
 116        worktree_store: Entity<WorktreeStore>,
 117        cx: &AsyncApp,
 118    ) -> Option<Self> {
 119        match settings {
 120            ContextServerSettings::Custom {
 121                enabled: _,
 122                command,
 123            } => Some(ContextServerConfiguration::Custom { command }),
 124            ContextServerSettings::Extension {
 125                enabled: _,
 126                settings,
 127            } => {
 128                let descriptor = cx
 129                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 130                    .ok()
 131                    .flatten()?;
 132
 133                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 134
 135                Some(ContextServerConfiguration::Extension { command, settings })
 136            }
 137        }
 138    }
 139}
 140
/// Constructor used instead of the default stdio-backed `ContextServer` when the
/// store creates a server for an id and configuration. In this file it is only
/// supplied through `ContextServerStore::test_maintain_server_loop`.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 143
/// Owns the lifecycle of the project's context servers: which exist, which are
/// running, and how they are (re)configured from settings and the registry.
pub struct ContextServerStore {
    // Context server settings last read from the settings store, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current lifecycle state of every server known to the store.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to pick the settings location and a fallback working directory.
    worktree_store: Entity<WorktreeStore>,
    // Used to find the active project directory for new stdio servers.
    project: WeakEntity<Project>,
    // Source of descriptor-provided (extension) servers.
    registry: Entity<ContextServerDescriptorRegistry>,
    // The in-flight `maintain_servers` run, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for constructing servers; `None` builds stdio servers.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a reconciliation is requested while one is already running, so
    // another pass is scheduled once the current one finishes.
    needs_server_update: bool,
    // Keeps the registry/settings observers alive for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
 155
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// A server's status changed. Emitted on every state transition, and with
    /// `ContextServerStatus::Stopped` when a server is removed from the store.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
 164
 165impl ContextServerStore {
 166    pub fn new(
 167        worktree_store: Entity<WorktreeStore>,
 168        weak_project: WeakEntity<Project>,
 169        cx: &mut Context<Self>,
 170    ) -> Self {
 171        Self::new_internal(
 172            true,
 173            None,
 174            ContextServerDescriptorRegistry::default_global(cx),
 175            worktree_store,
 176            weak_project,
 177            cx,
 178        )
 179    }
 180
 181    /// Returns all configured context server ids, regardless of enabled state.
 182    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 183        self.context_server_settings
 184            .keys()
 185            .cloned()
 186            .map(ContextServerId)
 187            .collect()
 188    }
 189
 190    #[cfg(any(test, feature = "test-support"))]
 191    pub fn test(
 192        registry: Entity<ContextServerDescriptorRegistry>,
 193        worktree_store: Entity<WorktreeStore>,
 194        weak_project: WeakEntity<Project>,
 195        cx: &mut Context<Self>,
 196    ) -> Self {
 197        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
 198    }
 199
 200    #[cfg(any(test, feature = "test-support"))]
 201    pub fn test_maintain_server_loop(
 202        context_server_factory: ContextServerFactory,
 203        registry: Entity<ContextServerDescriptorRegistry>,
 204        worktree_store: Entity<WorktreeStore>,
 205        weak_project: WeakEntity<Project>,
 206        cx: &mut Context<Self>,
 207    ) -> Self {
 208        Self::new_internal(
 209            true,
 210            Some(context_server_factory),
 211            registry,
 212            worktree_store,
 213            weak_project,
 214            cx,
 215        )
 216    }
 217
 218    fn new_internal(
 219        maintain_server_loop: bool,
 220        context_server_factory: Option<ContextServerFactory>,
 221        registry: Entity<ContextServerDescriptorRegistry>,
 222        worktree_store: Entity<WorktreeStore>,
 223        weak_project: WeakEntity<Project>,
 224        cx: &mut Context<Self>,
 225    ) -> Self {
 226        let subscriptions = if maintain_server_loop {
 227            vec![
 228                cx.observe(&registry, |this, _registry, cx| {
 229                    this.available_context_servers_changed(cx);
 230                }),
 231                cx.observe_global::<SettingsStore>(|this, cx| {
 232                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 233                    if &this.context_server_settings == settings {
 234                        return;
 235                    }
 236                    this.context_server_settings = settings.clone();
 237                    this.available_context_servers_changed(cx);
 238                }),
 239            ]
 240        } else {
 241            Vec::new()
 242        };
 243
 244        let mut this = Self {
 245            _subscriptions: subscriptions,
 246            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 247                .clone(),
 248            worktree_store,
 249            project: weak_project,
 250            registry,
 251            needs_server_update: false,
 252            servers: HashMap::default(),
 253            update_servers_task: None,
 254            context_server_factory,
 255        };
 256        if maintain_server_loop {
 257            this.available_context_servers_changed(cx);
 258        }
 259        this
 260    }
 261
 262    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 263        self.servers.get(id).map(|state| state.server())
 264    }
 265
 266    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 267        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 268            Some(server.clone())
 269        } else {
 270            None
 271        }
 272    }
 273
 274    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 275        self.servers.get(id).map(ContextServerStatus::from_state)
 276    }
 277
 278    pub fn configuration_for_server(
 279        &self,
 280        id: &ContextServerId,
 281    ) -> Option<Arc<ContextServerConfiguration>> {
 282        self.servers.get(id).map(|state| state.configuration())
 283    }
 284
 285    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 286        self.servers.keys().cloned().collect()
 287    }
 288
 289    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 290        self.servers
 291            .values()
 292            .filter_map(|state| {
 293                if let ContextServerState::Running { server, .. } = state {
 294                    Some(server.clone())
 295                } else {
 296                    None
 297                }
 298            })
 299            .collect()
 300    }
 301
 302    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 303        cx.spawn(async move |this, cx| {
 304            let this = this.upgrade().context("Context server store dropped")?;
 305            let settings = this
 306                .update(cx, |this, _| {
 307                    this.context_server_settings.get(&server.id().0).cloned()
 308                })
 309                .ok()
 310                .flatten()
 311                .context("Failed to get context server settings")?;
 312
 313            if !settings.enabled() {
 314                return Ok(());
 315            }
 316
 317            let (registry, worktree_store) = this.update(cx, |this, _| {
 318                (this.registry.clone(), this.worktree_store.clone())
 319            })?;
 320            let configuration = ContextServerConfiguration::from_settings(
 321                settings,
 322                server.id(),
 323                registry,
 324                worktree_store,
 325                cx,
 326            )
 327            .await
 328            .context("Failed to create context server configuration")?;
 329
 330            this.update(cx, |this, cx| {
 331                this.run_server(server, Arc::new(configuration), cx)
 332            })
 333        })
 334        .detach_and_log_err(cx);
 335    }
 336
 337    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 338        if matches!(
 339            self.servers.get(id),
 340            Some(ContextServerState::Stopped { .. })
 341        ) {
 342            return Ok(());
 343        }
 344
 345        let state = self
 346            .servers
 347            .remove(id)
 348            .context("Context server not found")?;
 349
 350        let server = state.server();
 351        let configuration = state.configuration();
 352        let mut result = Ok(());
 353        if let ContextServerState::Running { server, .. } = &state {
 354            result = server.stop();
 355        }
 356        drop(state);
 357
 358        self.update_server_state(
 359            id.clone(),
 360            ContextServerState::Stopped {
 361                configuration,
 362                server,
 363            },
 364            cx,
 365        );
 366
 367        result
 368    }
 369
 370    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 371        if let Some(state) = self.servers.get(id) {
 372            let configuration = state.configuration();
 373
 374            self.stop_server(&state.server().id(), cx)?;
 375            let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
 376            self.run_server(new_server, configuration, cx);
 377        }
 378        Ok(())
 379    }
 380
 381    fn run_server(
 382        &mut self,
 383        server: Arc<ContextServer>,
 384        configuration: Arc<ContextServerConfiguration>,
 385        cx: &mut Context<Self>,
 386    ) {
 387        let id = server.id();
 388        if matches!(
 389            self.servers.get(&id),
 390            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 391        ) {
 392            self.stop_server(&id, cx).log_err();
 393        }
 394
 395        let task = cx.spawn({
 396            let id = server.id();
 397            let server = server.clone();
 398            let configuration = configuration.clone();
 399            async move |this, cx| {
 400                match server.clone().start(cx).await {
 401                    Ok(_) => {
 402                        debug_assert!(server.client().is_some());
 403
 404                        this.update(cx, |this, cx| {
 405                            this.update_server_state(
 406                                id.clone(),
 407                                ContextServerState::Running {
 408                                    server,
 409                                    configuration,
 410                                },
 411                                cx,
 412                            )
 413                        })
 414                        .log_err()
 415                    }
 416                    Err(err) => {
 417                        log::error!("{} context server failed to start: {}", id, err);
 418                        this.update(cx, |this, cx| {
 419                            this.update_server_state(
 420                                id.clone(),
 421                                ContextServerState::Error {
 422                                    configuration,
 423                                    server,
 424                                    error: err.to_string().into(),
 425                                },
 426                                cx,
 427                            )
 428                        })
 429                        .log_err()
 430                    }
 431                };
 432            }
 433        });
 434
 435        self.update_server_state(
 436            id.clone(),
 437            ContextServerState::Starting {
 438                configuration,
 439                _task: task,
 440                server,
 441            },
 442            cx,
 443        );
 444    }
 445
 446    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 447        let state = self
 448            .servers
 449            .remove(id)
 450            .context("Context server not found")?;
 451        drop(state);
 452        cx.emit(Event::ServerStatusChanged {
 453            server_id: id.clone(),
 454            status: ContextServerStatus::Stopped,
 455        });
 456        Ok(())
 457    }
 458
 459    fn create_context_server(
 460        &self,
 461        id: ContextServerId,
 462        configuration: Arc<ContextServerConfiguration>,
 463        cx: &mut Context<Self>,
 464    ) -> Arc<ContextServer> {
 465        let root_path = self
 466            .project
 467            .read_with(cx, |project, cx| project.active_project_directory(cx))
 468            .ok()
 469            .flatten()
 470            .or_else(|| {
 471                self.worktree_store.read_with(cx, |store, cx| {
 472                    store.visible_worktrees(cx).fold(None, |acc, item| {
 473                        if acc.is_none() {
 474                            item.read(cx).root_dir()
 475                        } else {
 476                            acc
 477                        }
 478                    })
 479                })
 480            });
 481
 482        if let Some(factory) = self.context_server_factory.as_ref() {
 483            factory(id, configuration)
 484        } else {
 485            Arc::new(ContextServer::stdio(
 486                id,
 487                configuration.command().clone(),
 488                root_path,
 489            ))
 490        }
 491    }
 492
 493    fn resolve_context_server_settings<'a>(
 494        worktree_store: &'a Entity<WorktreeStore>,
 495        cx: &'a App,
 496    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 497        let location = worktree_store
 498            .read(cx)
 499            .visible_worktrees(cx)
 500            .next()
 501            .map(|worktree| settings::SettingsLocation {
 502                worktree_id: worktree.read(cx).id(),
 503                path: Path::new(""),
 504            });
 505        &ProjectSettings::get(location, cx).context_servers
 506    }
 507
 508    fn update_server_state(
 509        &mut self,
 510        id: ContextServerId,
 511        state: ContextServerState,
 512        cx: &mut Context<Self>,
 513    ) {
 514        let status = ContextServerStatus::from_state(&state);
 515        self.servers.insert(id.clone(), state);
 516        cx.emit(Event::ServerStatusChanged {
 517            server_id: id,
 518            status,
 519        });
 520    }
 521
 522    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 523        if self.update_servers_task.is_some() {
 524            self.needs_server_update = true;
 525        } else {
 526            self.needs_server_update = false;
 527            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 528                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 529                    log::error!("Error maintaining context servers: {}", err);
 530                }
 531
 532                this.update(cx, |this, cx| {
 533                    this.update_servers_task.take();
 534                    if this.needs_server_update {
 535                        this.available_context_servers_changed(cx);
 536                    }
 537                })?;
 538
 539                Ok(())
 540            }));
 541        }
 542    }
 543
 544    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 545        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 546            (
 547                this.context_server_settings.clone(),
 548                this.registry.clone(),
 549                this.worktree_store.clone(),
 550            )
 551        })?;
 552
 553        for (id, _) in
 554            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 555        {
 556            configured_servers
 557                .entry(id)
 558                .or_insert(ContextServerSettings::default_extension());
 559        }
 560
 561        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 562            configured_servers
 563                .into_iter()
 564                .partition(|(_, settings)| settings.enabled());
 565
 566        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 567            let id = ContextServerId(id);
 568            ContextServerConfiguration::from_settings(
 569                settings,
 570                id.clone(),
 571                registry.clone(),
 572                worktree_store.clone(),
 573                cx,
 574            )
 575            .map(|config| (id, config))
 576        }))
 577        .await
 578        .into_iter()
 579        .filter_map(|(id, config)| config.map(|config| (id, config)))
 580        .collect::<HashMap<_, _>>();
 581
 582        let mut servers_to_start = Vec::new();
 583        let mut servers_to_remove = HashSet::default();
 584        let mut servers_to_stop = HashSet::default();
 585
 586        this.update(cx, |this, cx| {
 587            for server_id in this.servers.keys() {
 588                // All servers that are not in desired_servers should be removed from the store.
 589                // This can happen if the user removed a server from the context server settings.
 590                if !configured_servers.contains_key(server_id) {
 591                    if disabled_servers.contains_key(&server_id.0) {
 592                        servers_to_stop.insert(server_id.clone());
 593                    } else {
 594                        servers_to_remove.insert(server_id.clone());
 595                    }
 596                }
 597            }
 598
 599            for (id, config) in configured_servers {
 600                let state = this.servers.get(&id);
 601                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 602                let existing_config = state.as_ref().map(|state| state.configuration());
 603                if existing_config.as_deref() != Some(&config) || is_stopped {
 604                    let config = Arc::new(config);
 605                    let server = this.create_context_server(id.clone(), config.clone(), cx);
 606                    servers_to_start.push((server, config));
 607                    if this.servers.contains_key(&id) {
 608                        servers_to_stop.insert(id);
 609                    }
 610                }
 611            }
 612        })?;
 613
 614        this.update(cx, |this, cx| {
 615            for id in servers_to_stop {
 616                this.stop_server(&id, cx)?;
 617            }
 618            for id in servers_to_remove {
 619                this.remove_server(&id, cx)?;
 620            }
 621            for (server, config) in servers_to_start {
 622                this.run_server(server, config, cx);
 623            }
 624            anyhow::Ok(())
 625        })?
 626    }
 627}
 628
 629#[cfg(test)]
 630mod tests {
 631    use super::*;
 632    use crate::{
 633        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 634        project_settings::ProjectSettings,
 635    };
 636    use context_server::test::create_fake_transport;
 637    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 638    use serde_json::json;
 639    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 640    use util::path;
 641
    /// `start_server` drives each server to `Running` independently, servers that
    /// were never started have no status, and `stop_server` reports `Stopped`
    /// for that server only.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        // `ContextServerStore::test` does not run the maintenance loop, so
        // servers only start when the test starts them.
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        // Stopping is synchronous, so no `run_until_parked` is needed here.
        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }
 721
    /// Starting two servers and then stopping one emits exactly the expected
    /// sequence of `ServerStatusChanged` events, checked by
    /// `assert_server_events`.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // Kept alive until the end of the test so every event below is observed.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }
 783
    /// Starting two servers with the same id back-to-back stops the first one
    /// (reported as `Stopped`) and leaves only the second one `Running`. The test
    /// runs for several iterations to exercise different task interleavings.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        // Both starts are issued before any task runs.
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }
 844
 845    #[gpui::test]
 846    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 847        const SERVER_1_ID: &str = "mcp-1";
 848        const SERVER_2_ID: &str = "mcp-2";
 849
 850        let server_1_id = ContextServerId(SERVER_1_ID.into());
 851        let server_2_id = ContextServerId(SERVER_2_ID.into());
 852
 853        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 854
 855        let (_fs, project) = setup_context_server_test(
 856            cx,
 857            json!({"code.rs": ""}),
 858            vec![(
 859                SERVER_1_ID.into(),
 860                ContextServerSettings::Extension {
 861                    enabled: true,
 862                    settings: json!({
 863                        "somevalue": true
 864                    }),
 865                },
 866            )],
 867        )
 868        .await;
 869
 870        let executor = cx.executor();
 871        let registry = cx.new(|cx| {
 872            let mut registry = ContextServerDescriptorRegistry::new();
 873            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 874            registry
 875        });
 876        let store = cx.new(|cx| {
 877            ContextServerStore::test_maintain_server_loop(
 878                Box::new(move |id, _| {
 879                    Arc::new(ContextServer::new(
 880                        id.clone(),
 881                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 882                    ))
 883                }),
 884                registry.clone(),
 885                project.read(cx).worktree_store(),
 886                project.downgrade(),
 887                cx,
 888            )
 889        });
 890
 891        // Ensure that mcp-1 starts up
 892        {
 893            let _server_events = assert_server_events(
 894                &store,
 895                vec![
 896                    (server_1_id.clone(), ContextServerStatus::Starting),
 897                    (server_1_id.clone(), ContextServerStatus::Running),
 898                ],
 899                cx,
 900            );
 901            cx.run_until_parked();
 902        }
 903
 904        // Ensure that mcp-1 is restarted when the configuration was changed
 905        {
 906            let _server_events = assert_server_events(
 907                &store,
 908                vec![
 909                    (server_1_id.clone(), ContextServerStatus::Stopped),
 910                    (server_1_id.clone(), ContextServerStatus::Starting),
 911                    (server_1_id.clone(), ContextServerStatus::Running),
 912                ],
 913                cx,
 914            );
 915            set_context_server_configuration(
 916                vec![(
 917                    server_1_id.0.clone(),
 918                    ContextServerSettings::Extension {
 919                        enabled: true,
 920                        settings: json!({
 921                            "somevalue": false
 922                        }),
 923                    },
 924                )],
 925                cx,
 926            );
 927
 928            cx.run_until_parked();
 929        }
 930
 931        // Ensure that mcp-1 is not restarted when the configuration was not changed
 932        {
 933            let _server_events = assert_server_events(&store, vec![], cx);
 934            set_context_server_configuration(
 935                vec![(
 936                    server_1_id.0.clone(),
 937                    ContextServerSettings::Extension {
 938                        enabled: true,
 939                        settings: json!({
 940                            "somevalue": false
 941                        }),
 942                    },
 943                )],
 944                cx,
 945            );
 946
 947            cx.run_until_parked();
 948        }
 949
 950        // Ensure that mcp-2 is started once it is added to the settings
 951        {
 952            let _server_events = assert_server_events(
 953                &store,
 954                vec![
 955                    (server_2_id.clone(), ContextServerStatus::Starting),
 956                    (server_2_id.clone(), ContextServerStatus::Running),
 957                ],
 958                cx,
 959            );
 960            set_context_server_configuration(
 961                vec![
 962                    (
 963                        server_1_id.0.clone(),
 964                        ContextServerSettings::Extension {
 965                            enabled: true,
 966                            settings: json!({
 967                                "somevalue": false
 968                            }),
 969                        },
 970                    ),
 971                    (
 972                        server_2_id.0.clone(),
 973                        ContextServerSettings::Custom {
 974                            enabled: true,
 975                            command: ContextServerCommand {
 976                                path: "somebinary".into(),
 977                                args: vec!["arg".to_string()],
 978                                env: None,
 979                                timeout: None,
 980                            },
 981                        },
 982                    ),
 983                ],
 984                cx,
 985            );
 986
 987            cx.run_until_parked();
 988        }
 989
 990        // Ensure that mcp-2 is restarted once the args have changed
 991        {
 992            let _server_events = assert_server_events(
 993                &store,
 994                vec![
 995                    (server_2_id.clone(), ContextServerStatus::Stopped),
 996                    (server_2_id.clone(), ContextServerStatus::Starting),
 997                    (server_2_id.clone(), ContextServerStatus::Running),
 998                ],
 999                cx,
1000            );
1001            set_context_server_configuration(
1002                vec![
1003                    (
1004                        server_1_id.0.clone(),
1005                        ContextServerSettings::Extension {
1006                            enabled: true,
1007                            settings: json!({
1008                                "somevalue": false
1009                            }),
1010                        },
1011                    ),
1012                    (
1013                        server_2_id.0.clone(),
1014                        ContextServerSettings::Custom {
1015                            enabled: true,
1016                            command: ContextServerCommand {
1017                                path: "somebinary".into(),
1018                                args: vec!["anotherArg".to_string()],
1019                                env: None,
1020                                timeout: None,
1021                            },
1022                        },
1023                    ),
1024                ],
1025                cx,
1026            );
1027
1028            cx.run_until_parked();
1029        }
1030
1031        // Ensure that mcp-2 is removed once it is removed from the settings
1032        {
1033            let _server_events = assert_server_events(
1034                &store,
1035                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1036                cx,
1037            );
1038            set_context_server_configuration(
1039                vec![(
1040                    server_1_id.0.clone(),
1041                    ContextServerSettings::Extension {
1042                        enabled: true,
1043                        settings: json!({
1044                            "somevalue": false
1045                        }),
1046                    },
1047                )],
1048                cx,
1049            );
1050
1051            cx.run_until_parked();
1052
1053            cx.update(|cx| {
1054                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1055            });
1056        }
1057
1058        // Ensure that nothing happens if the settings do not change
1059        {
1060            let _server_events = assert_server_events(&store, vec![], cx);
1061            set_context_server_configuration(
1062                vec![(
1063                    server_1_id.0.clone(),
1064                    ContextServerSettings::Extension {
1065                        enabled: true,
1066                        settings: json!({
1067                            "somevalue": false
1068                        }),
1069                    },
1070                )],
1071                cx,
1072            );
1073
1074            cx.run_until_parked();
1075
1076            cx.update(|cx| {
1077                assert_eq!(
1078                    store.read(cx).status_for_server(&server_1_id),
1079                    Some(ContextServerStatus::Running)
1080                );
1081                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1082            });
1083        }
1084    }
1085
1086    #[gpui::test]
1087    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1088        const SERVER_1_ID: &str = "mcp-1";
1089
1090        let server_1_id = ContextServerId(SERVER_1_ID.into());
1091
1092        let (_fs, project) = setup_context_server_test(
1093            cx,
1094            json!({"code.rs": ""}),
1095            vec![(
1096                SERVER_1_ID.into(),
1097                ContextServerSettings::Custom {
1098                    enabled: true,
1099                    command: ContextServerCommand {
1100                        path: "somebinary".into(),
1101                        args: vec!["arg".to_string()],
1102                        env: None,
1103                        timeout: None,
1104                    },
1105                },
1106            )],
1107        )
1108        .await;
1109
1110        let executor = cx.executor();
1111        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1112        let store = cx.new(|cx| {
1113            ContextServerStore::test_maintain_server_loop(
1114                Box::new(move |id, _| {
1115                    Arc::new(ContextServer::new(
1116                        id.clone(),
1117                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1118                    ))
1119                }),
1120                registry.clone(),
1121                project.read(cx).worktree_store(),
1122                project.downgrade(),
1123                cx,
1124            )
1125        });
1126
1127        // Ensure that mcp-1 starts up
1128        {
1129            let _server_events = assert_server_events(
1130                &store,
1131                vec![
1132                    (server_1_id.clone(), ContextServerStatus::Starting),
1133                    (server_1_id.clone(), ContextServerStatus::Running),
1134                ],
1135                cx,
1136            );
1137            cx.run_until_parked();
1138        }
1139
1140        // Ensure that mcp-1 is stopped once it is disabled.
1141        {
1142            let _server_events = assert_server_events(
1143                &store,
1144                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1145                cx,
1146            );
1147            set_context_server_configuration(
1148                vec![(
1149                    server_1_id.0.clone(),
1150                    ContextServerSettings::Custom {
1151                        enabled: false,
1152                        command: ContextServerCommand {
1153                            path: "somebinary".into(),
1154                            args: vec!["arg".to_string()],
1155                            env: None,
1156                            timeout: None,
1157                        },
1158                    },
1159                )],
1160                cx,
1161            );
1162
1163            cx.run_until_parked();
1164        }
1165
1166        // Ensure that mcp-1 is started once it is enabled again.
1167        {
1168            let _server_events = assert_server_events(
1169                &store,
1170                vec![
1171                    (server_1_id.clone(), ContextServerStatus::Starting),
1172                    (server_1_id.clone(), ContextServerStatus::Running),
1173                ],
1174                cx,
1175            );
1176            set_context_server_configuration(
1177                vec![(
1178                    server_1_id.0.clone(),
1179                    ContextServerSettings::Custom {
1180                        enabled: true,
1181                        command: ContextServerCommand {
1182                            path: "somebinary".into(),
1183                            args: vec!["arg".to_string()],
1184                            timeout: None,
1185                            env: None,
1186                        },
1187                    },
1188                )],
1189                cx,
1190            );
1191
1192            cx.run_until_parked();
1193        }
1194    }
1195
1196    fn set_context_server_configuration(
1197        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1198        cx: &mut TestAppContext,
1199    ) {
1200        cx.update(|cx| {
1201            SettingsStore::update_global(cx, |store, cx| {
1202                let mut settings = ProjectSettings::default();
1203                for (id, config) in context_servers {
1204                    settings.context_servers.insert(id, config);
1205                }
1206                store
1207                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1208                    .unwrap();
1209            })
1210        });
1211    }
1212
1213    struct ServerEvents {
1214        received_event_count: Rc<RefCell<usize>>,
1215        expected_event_count: usize,
1216        _subscription: Subscription,
1217    }
1218
1219    impl Drop for ServerEvents {
1220        fn drop(&mut self) {
1221            let actual_event_count = *self.received_event_count.borrow();
1222            assert_eq!(
1223                actual_event_count, self.expected_event_count,
1224                "
1225                Expected to receive {} context server store events, but received {} events",
1226                self.expected_event_count, actual_event_count
1227            );
1228        }
1229    }
1230
1231    fn dummy_server_settings() -> ContextServerSettings {
1232        ContextServerSettings::Custom {
1233            enabled: true,
1234            command: ContextServerCommand {
1235                path: "somebinary".into(),
1236                args: vec!["arg".to_string()],
1237                env: None,
1238                timeout: None,
1239            },
1240        }
1241    }
1242
1243    fn assert_server_events(
1244        store: &Entity<ContextServerStore>,
1245        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1246        cx: &mut TestAppContext,
1247    ) -> ServerEvents {
1248        cx.update(|cx| {
1249            let mut ix = 0;
1250            let received_event_count = Rc::new(RefCell::new(0));
1251            let expected_event_count = expected_events.len();
1252            let subscription = cx.subscribe(store, {
1253                let received_event_count = received_event_count.clone();
1254                move |_, event, _| match event {
1255                    Event::ServerStatusChanged {
1256                        server_id: actual_server_id,
1257                        status: actual_status,
1258                    } => {
1259                        let (expected_server_id, expected_status) = &expected_events[ix];
1260
1261                        assert_eq!(
1262                            actual_server_id, expected_server_id,
1263                            "Expected different server id at index {}",
1264                            ix
1265                        );
1266                        assert_eq!(
1267                            actual_status, expected_status,
1268                            "Expected different status at index {}",
1269                            ix
1270                        );
1271                        ix += 1;
1272                        *received_event_count.borrow_mut() += 1;
1273                    }
1274                }
1275            });
1276            ServerEvents {
1277                expected_event_count,
1278                received_event_count,
1279                _subscription: subscription,
1280            }
1281        })
1282    }
1283
1284    async fn setup_context_server_test(
1285        cx: &mut TestAppContext,
1286        files: serde_json::Value,
1287        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1288    ) -> (Arc<FakeFs>, Entity<Project>) {
1289        cx.update(|cx| {
1290            let settings_store = SettingsStore::test(cx);
1291            cx.set_global(settings_store);
1292            Project::init_settings(cx);
1293            let mut settings = ProjectSettings::get_global(cx).clone();
1294            for (id, config) in context_server_configurations {
1295                settings.context_servers.insert(id, config);
1296            }
1297            ProjectSettings::override_global(settings, cx);
1298        });
1299
1300        let fs = FakeFs::new(cx.executor());
1301        fs.insert_tree(path!("/test"), files).await;
1302        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1303
1304        (fs, project)
1305    }
1306
1307    struct FakeContextServerDescriptor {
1308        path: PathBuf,
1309    }
1310
1311    impl FakeContextServerDescriptor {
1312        fn new(path: impl Into<PathBuf>) -> Self {
1313            Self { path: path.into() }
1314        }
1315    }
1316
1317    impl ContextServerDescriptor for FakeContextServerDescriptor {
1318        fn command(
1319            &self,
1320            _worktree_store: Entity<WorktreeStore>,
1321            _cx: &AsyncApp,
1322        ) -> Task<Result<ContextServerCommand>> {
1323            Task::ready(Ok(ContextServerCommand {
1324                path: self.path.clone(),
1325                args: vec!["arg1".to_string(), "arg2".to_string()],
1326                env: None,
1327                timeout: None,
1328            }))
1329        }
1330
1331        fn configuration(
1332            &self,
1333            _worktree_store: Entity<WorktreeStore>,
1334            _cx: &AsyncApp,
1335        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1336            Task::ready(Ok(None))
1337        }
1338    }
1339}