context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    Project,
  17    project_settings::{ContextServerSettings, ProjectSettings},
  18    worktree_store::WorktreeStore,
  19};
  20
  21pub fn init(cx: &mut App) {
  22    extension::init(cx);
  23}
  24
// Declares the `Restart` action under the `context_server` action namespace.
// The `///` line below is consumed by the macro as the action's documentation,
// so its text is part of the action definition, not just a source comment.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  32
  33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  34pub enum ContextServerStatus {
  35    Starting,
  36    Running,
  37    Stopped,
  38    Error(Arc<str>),
  39}
  40
  41impl ContextServerStatus {
  42    fn from_state(state: &ContextServerState) -> Self {
  43        match state {
  44            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  45            ContextServerState::Running { .. } => ContextServerStatus::Running,
  46            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  47            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  48        }
  49    }
  50}
  51
/// Internal per-server bookkeeping kept by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was launched
/// with. The store can therefore stop, restart, or compare configurations in
/// any phase.
enum ContextServerState {
    /// A start attempt is in flight.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // The task performing the start attempt. It is owned by this state, so
        // replacing or removing the state drops the task (with gpui `Task`
        // semantics this presumably cancels the attempt — confirm).
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped; the handle and configuration are retained so it
    /// can be restarted or reconciled later.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The last start attempt failed.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Rendered error from the failed start attempt.
        error: Arc<str>,
    },
}
  72
  73impl ContextServerState {
  74    pub fn server(&self) -> Arc<ContextServer> {
  75        match self {
  76            ContextServerState::Starting { server, .. } => server.clone(),
  77            ContextServerState::Running { server, .. } => server.clone(),
  78            ContextServerState::Stopped { server, .. } => server.clone(),
  79            ContextServerState::Error { server, .. } => server.clone(),
  80        }
  81    }
  82
  83    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  84        match self {
  85            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  86            ContextServerState::Running { configuration, .. } => configuration.clone(),
  87            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  88            ContextServerState::Error { configuration, .. } => configuration.clone(),
  89        }
  90    }
  91}
  92
/// Resolved launch configuration for a context server.
///
/// Equality matters: the store restarts a server when its freshly resolved
/// configuration differs from the one it is currently tracked with.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server whose command is taken verbatim from the user's settings.
    Custom {
        command: ContextServerCommand,
    },
    /// A server provided through the descriptor registry; its command is
    /// produced by the registered descriptor.
    Extension {
        command: ContextServerCommand,
        // The extension's settings blob from the user's configuration. It is
        // not read in this file beyond equality, so a change to it alone is
        // enough to trigger a restart.
        settings: serde_json::Value,
    },
}
 103
 104impl ContextServerConfiguration {
 105    pub fn command(&self) -> &ContextServerCommand {
 106        match self {
 107            ContextServerConfiguration::Custom { command } => command,
 108            ContextServerConfiguration::Extension { command, .. } => command,
 109        }
 110    }
 111
 112    pub async fn from_settings(
 113        settings: ContextServerSettings,
 114        id: ContextServerId,
 115        registry: Entity<ContextServerDescriptorRegistry>,
 116        worktree_store: Entity<WorktreeStore>,
 117        cx: &AsyncApp,
 118    ) -> Option<Self> {
 119        match settings {
 120            ContextServerSettings::Custom {
 121                enabled: _,
 122                command,
 123            } => Some(ContextServerConfiguration::Custom { command }),
 124            ContextServerSettings::Extension {
 125                enabled: _,
 126                settings,
 127            } => {
 128                let descriptor = cx
 129                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 130                    .ok()
 131                    .flatten()?;
 132
 133                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 134
 135                Some(ContextServerConfiguration::Extension { command, settings })
 136            }
 137        }
 138    }
 139}
 140
/// Constructor used instead of `ContextServer::stdio` when the store builds a
/// server. Tests inject one to supply fake transports.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 143
/// Tracks the context servers of a project and manages their lifecycle.
///
/// Emits [`Event::ServerStatusChanged`] whenever a tracked server changes state
/// or is removed.
pub struct ContextServerStore {
    // Snapshot of the `context_servers` section of project settings, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current lifecycle state of every tracked server.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    // Used to pick a working directory for new stdio servers. Held weakly
    // (presumably to avoid a reference cycle with `Project` — confirm).
    project: WeakEntity<Project>,
    // Source of registry-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // The reconciliation pass currently in flight, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // Test hook that overrides how servers are constructed.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a reconciliation is requested while another is still running.
    needs_server_update: bool,
    // Registry and settings observers; empty when the maintenance loop is off.
    _subscriptions: Vec<Subscription>,
}
 155
/// Notifications emitted by [`ContextServerStore`].
pub enum Event {
    /// A server's status changed. Also emitted with
    /// [`ContextServerStatus::Stopped`] when a server is removed from the store.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
 164
 165impl ContextServerStore {
 166    pub fn new(
 167        worktree_store: Entity<WorktreeStore>,
 168        weak_project: WeakEntity<Project>,
 169        cx: &mut Context<Self>,
 170    ) -> Self {
 171        Self::new_internal(
 172            true,
 173            None,
 174            ContextServerDescriptorRegistry::default_global(cx),
 175            worktree_store,
 176            weak_project,
 177            cx,
 178        )
 179    }
 180
 181    /// Returns all configured context server ids, regardless of enabled state.
 182    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 183        self.context_server_settings
 184            .keys()
 185            .cloned()
 186            .map(ContextServerId)
 187            .collect()
 188    }
 189
 190    #[cfg(any(test, feature = "test-support"))]
 191    pub fn test(
 192        registry: Entity<ContextServerDescriptorRegistry>,
 193        worktree_store: Entity<WorktreeStore>,
 194        weak_project: WeakEntity<Project>,
 195        cx: &mut Context<Self>,
 196    ) -> Self {
 197        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
 198    }
 199
 200    #[cfg(any(test, feature = "test-support"))]
 201    pub fn test_maintain_server_loop(
 202        context_server_factory: ContextServerFactory,
 203        registry: Entity<ContextServerDescriptorRegistry>,
 204        worktree_store: Entity<WorktreeStore>,
 205        weak_project: WeakEntity<Project>,
 206        cx: &mut Context<Self>,
 207    ) -> Self {
 208        Self::new_internal(
 209            true,
 210            Some(context_server_factory),
 211            registry,
 212            worktree_store,
 213            weak_project,
 214            cx,
 215        )
 216    }
 217
 218    fn new_internal(
 219        maintain_server_loop: bool,
 220        context_server_factory: Option<ContextServerFactory>,
 221        registry: Entity<ContextServerDescriptorRegistry>,
 222        worktree_store: Entity<WorktreeStore>,
 223        weak_project: WeakEntity<Project>,
 224        cx: &mut Context<Self>,
 225    ) -> Self {
 226        let subscriptions = if maintain_server_loop {
 227            vec![
 228                cx.observe(&registry, |this, _registry, cx| {
 229                    this.available_context_servers_changed(cx);
 230                }),
 231                cx.observe_global::<SettingsStore>(|this, cx| {
 232                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 233                    if &this.context_server_settings == settings {
 234                        return;
 235                    }
 236                    this.context_server_settings = settings.clone();
 237                    this.available_context_servers_changed(cx);
 238                }),
 239            ]
 240        } else {
 241            Vec::new()
 242        };
 243
 244        let mut this = Self {
 245            _subscriptions: subscriptions,
 246            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 247                .clone(),
 248            worktree_store,
 249            project: weak_project,
 250            registry,
 251            needs_server_update: false,
 252            servers: HashMap::default(),
 253            update_servers_task: None,
 254            context_server_factory,
 255        };
 256        if maintain_server_loop {
 257            this.available_context_servers_changed(cx);
 258        }
 259        this
 260    }
 261
 262    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 263        self.servers.get(id).map(|state| state.server())
 264    }
 265
 266    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 267        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 268            Some(server.clone())
 269        } else {
 270            None
 271        }
 272    }
 273
 274    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 275        self.servers.get(id).map(ContextServerStatus::from_state)
 276    }
 277
 278    pub fn configuration_for_server(
 279        &self,
 280        id: &ContextServerId,
 281    ) -> Option<Arc<ContextServerConfiguration>> {
 282        self.servers.get(id).map(|state| state.configuration())
 283    }
 284
 285    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 286        self.servers.keys().cloned().collect()
 287    }
 288
 289    pub fn all_registry_descriptor_ids(&self, cx: &App) -> Vec<ContextServerId> {
 290        self.registry
 291            .read(cx)
 292            .context_server_descriptors()
 293            .into_iter()
 294            .map(|(id, _)| ContextServerId(id))
 295            .collect()
 296    }
 297
 298    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 299        self.servers
 300            .values()
 301            .filter_map(|state| {
 302                if let ContextServerState::Running { server, .. } = state {
 303                    Some(server.clone())
 304                } else {
 305                    None
 306                }
 307            })
 308            .collect()
 309    }
 310
 311    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 312        cx.spawn(async move |this, cx| {
 313            let this = this.upgrade().context("Context server store dropped")?;
 314            let settings = this
 315                .update(cx, |this, _| {
 316                    this.context_server_settings.get(&server.id().0).cloned()
 317                })
 318                .ok()
 319                .flatten()
 320                .context("Failed to get context server settings")?;
 321
 322            if !settings.enabled() {
 323                return Ok(());
 324            }
 325
 326            let (registry, worktree_store) = this.update(cx, |this, _| {
 327                (this.registry.clone(), this.worktree_store.clone())
 328            })?;
 329            let configuration = ContextServerConfiguration::from_settings(
 330                settings,
 331                server.id(),
 332                registry,
 333                worktree_store,
 334                cx,
 335            )
 336            .await
 337            .context("Failed to create context server configuration")?;
 338
 339            this.update(cx, |this, cx| {
 340                this.run_server(server, Arc::new(configuration), cx)
 341            })
 342        })
 343        .detach_and_log_err(cx);
 344    }
 345
 346    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 347        if matches!(
 348            self.servers.get(id),
 349            Some(ContextServerState::Stopped { .. })
 350        ) {
 351            return Ok(());
 352        }
 353
 354        let state = self
 355            .servers
 356            .remove(id)
 357            .context("Context server not found")?;
 358
 359        let server = state.server();
 360        let configuration = state.configuration();
 361        let mut result = Ok(());
 362        if let ContextServerState::Running { server, .. } = &state {
 363            result = server.stop();
 364        }
 365        drop(state);
 366
 367        self.update_server_state(
 368            id.clone(),
 369            ContextServerState::Stopped {
 370                configuration,
 371                server,
 372            },
 373            cx,
 374        );
 375
 376        result
 377    }
 378
 379    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 380        if let Some(state) = self.servers.get(id) {
 381            let configuration = state.configuration();
 382
 383            self.stop_server(&state.server().id(), cx)?;
 384            let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
 385            self.run_server(new_server, configuration, cx);
 386        }
 387        Ok(())
 388    }
 389
 390    fn run_server(
 391        &mut self,
 392        server: Arc<ContextServer>,
 393        configuration: Arc<ContextServerConfiguration>,
 394        cx: &mut Context<Self>,
 395    ) {
 396        let id = server.id();
 397        if matches!(
 398            self.servers.get(&id),
 399            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 400        ) {
 401            self.stop_server(&id, cx).log_err();
 402        }
 403
 404        let task = cx.spawn({
 405            let id = server.id();
 406            let server = server.clone();
 407            let configuration = configuration.clone();
 408            async move |this, cx| {
 409                match server.clone().start(cx).await {
 410                    Ok(_) => {
 411                        debug_assert!(server.client().is_some());
 412
 413                        this.update(cx, |this, cx| {
 414                            this.update_server_state(
 415                                id.clone(),
 416                                ContextServerState::Running {
 417                                    server,
 418                                    configuration,
 419                                },
 420                                cx,
 421                            )
 422                        })
 423                        .log_err()
 424                    }
 425                    Err(err) => {
 426                        log::error!("{} context server failed to start: {}", id, err);
 427                        this.update(cx, |this, cx| {
 428                            this.update_server_state(
 429                                id.clone(),
 430                                ContextServerState::Error {
 431                                    configuration,
 432                                    server,
 433                                    error: err.to_string().into(),
 434                                },
 435                                cx,
 436                            )
 437                        })
 438                        .log_err()
 439                    }
 440                };
 441            }
 442        });
 443
 444        self.update_server_state(
 445            id.clone(),
 446            ContextServerState::Starting {
 447                configuration,
 448                _task: task,
 449                server,
 450            },
 451            cx,
 452        );
 453    }
 454
 455    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 456        let state = self
 457            .servers
 458            .remove(id)
 459            .context("Context server not found")?;
 460        drop(state);
 461        cx.emit(Event::ServerStatusChanged {
 462            server_id: id.clone(),
 463            status: ContextServerStatus::Stopped,
 464        });
 465        Ok(())
 466    }
 467
 468    fn create_context_server(
 469        &self,
 470        id: ContextServerId,
 471        configuration: Arc<ContextServerConfiguration>,
 472        cx: &mut Context<Self>,
 473    ) -> Arc<ContextServer> {
 474        let root_path = self
 475            .project
 476            .read_with(cx, |project, cx| project.active_project_directory(cx))
 477            .ok()
 478            .flatten()
 479            .or_else(|| {
 480                self.worktree_store.read_with(cx, |store, cx| {
 481                    store.visible_worktrees(cx).fold(None, |acc, item| {
 482                        if acc.is_none() {
 483                            item.read(cx).root_dir()
 484                        } else {
 485                            acc
 486                        }
 487                    })
 488                })
 489            });
 490
 491        if let Some(factory) = self.context_server_factory.as_ref() {
 492            factory(id, configuration)
 493        } else {
 494            Arc::new(ContextServer::stdio(
 495                id,
 496                configuration.command().clone(),
 497                root_path,
 498            ))
 499        }
 500    }
 501
 502    fn resolve_context_server_settings<'a>(
 503        worktree_store: &'a Entity<WorktreeStore>,
 504        cx: &'a App,
 505    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 506        let location = worktree_store
 507            .read(cx)
 508            .visible_worktrees(cx)
 509            .next()
 510            .map(|worktree| settings::SettingsLocation {
 511                worktree_id: worktree.read(cx).id(),
 512                path: Path::new(""),
 513            });
 514        &ProjectSettings::get(location, cx).context_servers
 515    }
 516
 517    fn update_server_state(
 518        &mut self,
 519        id: ContextServerId,
 520        state: ContextServerState,
 521        cx: &mut Context<Self>,
 522    ) {
 523        let status = ContextServerStatus::from_state(&state);
 524        self.servers.insert(id.clone(), state);
 525        cx.emit(Event::ServerStatusChanged {
 526            server_id: id,
 527            status,
 528        });
 529    }
 530
 531    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 532        if self.update_servers_task.is_some() {
 533            self.needs_server_update = true;
 534        } else {
 535            self.needs_server_update = false;
 536            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 537                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 538                    log::error!("Error maintaining context servers: {}", err);
 539                }
 540
 541                this.update(cx, |this, cx| {
 542                    this.update_servers_task.take();
 543                    if this.needs_server_update {
 544                        this.available_context_servers_changed(cx);
 545                    }
 546                })?;
 547
 548                Ok(())
 549            }));
 550        }
 551    }
 552
 553    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 554        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 555            (
 556                this.context_server_settings.clone(),
 557                this.registry.clone(),
 558                this.worktree_store.clone(),
 559            )
 560        })?;
 561
 562        for (id, _) in
 563            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 564        {
 565            configured_servers
 566                .entry(id)
 567                .or_insert(ContextServerSettings::default_extension());
 568        }
 569
 570        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 571            configured_servers
 572                .into_iter()
 573                .partition(|(_, settings)| settings.enabled());
 574
 575        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 576            let id = ContextServerId(id);
 577            ContextServerConfiguration::from_settings(
 578                settings,
 579                id.clone(),
 580                registry.clone(),
 581                worktree_store.clone(),
 582                cx,
 583            )
 584            .map(|config| (id, config))
 585        }))
 586        .await
 587        .into_iter()
 588        .filter_map(|(id, config)| config.map(|config| (id, config)))
 589        .collect::<HashMap<_, _>>();
 590
 591        let mut servers_to_start = Vec::new();
 592        let mut servers_to_remove = HashSet::default();
 593        let mut servers_to_stop = HashSet::default();
 594
 595        this.update(cx, |this, cx| {
 596            for server_id in this.servers.keys() {
 597                // All servers that are not in desired_servers should be removed from the store.
 598                // This can happen if the user removed a server from the context server settings.
 599                if !configured_servers.contains_key(server_id) {
 600                    if disabled_servers.contains_key(&server_id.0) {
 601                        servers_to_stop.insert(server_id.clone());
 602                    } else {
 603                        servers_to_remove.insert(server_id.clone());
 604                    }
 605                }
 606            }
 607
 608            for (id, config) in configured_servers {
 609                let state = this.servers.get(&id);
 610                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 611                let existing_config = state.as_ref().map(|state| state.configuration());
 612                if existing_config.as_deref() != Some(&config) || is_stopped {
 613                    let config = Arc::new(config);
 614                    let server = this.create_context_server(id.clone(), config.clone(), cx);
 615                    servers_to_start.push((server, config));
 616                    if this.servers.contains_key(&id) {
 617                        servers_to_stop.insert(id);
 618                    }
 619                }
 620            }
 621        })?;
 622
 623        this.update(cx, |this, cx| {
 624            for id in servers_to_stop {
 625                this.stop_server(&id, cx)?;
 626            }
 627            for id in servers_to_remove {
 628                this.remove_server(&id, cx)?;
 629            }
 630            for (server, config) in servers_to_start {
 631                this.run_server(server, config, cx);
 632            }
 633            anyhow::Ok(())
 634        })?
 635    }
 636}
 637
 638#[cfg(test)]
 639mod tests {
 640    use super::*;
 641    use crate::{
 642        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 643        project_settings::ProjectSettings,
 644    };
 645    use context_server::test::create_fake_transport;
 646    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 647    use serde_json::json;
 648    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 649    use util::path;
 650
    /// Starting and stopping servers by hand is reflected in `status_for_server`.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        // `start_server` looks up each server's settings entry, so both ids
        // must be configured (presumably as enabled by `dummy_server_settings`).
        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        // `ContextServerStore::test` runs no maintenance loop, so servers only
        // start when the test asks for it.
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        // Let the asynchronous start complete.
        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        // `stop_server` updates the state synchronously, so no parking is needed.
        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }
 730
    /// Manual starts and stops emit `ServerStatusChanged` events in order.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // The expectation must be registered before any server starts so no
        // event is missed. It is kept alive until the end of the test
        // (`assert_server_events` is defined outside this view; presumably it
        // checks the recorded sequence against this list).
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }
 792
    /// Two starts for the same id issued back to back leave exactly one running
    /// server. Repeated across iterations to exercise different schedules.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        // Deliberately issue both starts without parking in between. The second
        // `run_server` then finds the first still `Starting` and stops it
        // before taking over.
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }
 853
 854    #[gpui::test]
 855    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 856        const SERVER_1_ID: &str = "mcp-1";
 857        const SERVER_2_ID: &str = "mcp-2";
 858
 859        let server_1_id = ContextServerId(SERVER_1_ID.into());
 860        let server_2_id = ContextServerId(SERVER_2_ID.into());
 861
 862        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 863
 864        let (_fs, project) = setup_context_server_test(
 865            cx,
 866            json!({"code.rs": ""}),
 867            vec![(
 868                SERVER_1_ID.into(),
 869                ContextServerSettings::Extension {
 870                    enabled: true,
 871                    settings: json!({
 872                        "somevalue": true
 873                    }),
 874                },
 875            )],
 876        )
 877        .await;
 878
 879        let executor = cx.executor();
 880        let registry = cx.new(|cx| {
 881            let mut registry = ContextServerDescriptorRegistry::new();
 882            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 883            registry
 884        });
 885        let store = cx.new(|cx| {
 886            ContextServerStore::test_maintain_server_loop(
 887                Box::new(move |id, _| {
 888                    Arc::new(ContextServer::new(
 889                        id.clone(),
 890                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 891                    ))
 892                }),
 893                registry.clone(),
 894                project.read(cx).worktree_store(),
 895                project.downgrade(),
 896                cx,
 897            )
 898        });
 899
 900        // Ensure that mcp-1 starts up
 901        {
 902            let _server_events = assert_server_events(
 903                &store,
 904                vec![
 905                    (server_1_id.clone(), ContextServerStatus::Starting),
 906                    (server_1_id.clone(), ContextServerStatus::Running),
 907                ],
 908                cx,
 909            );
 910            cx.run_until_parked();
 911        }
 912
 913        // Ensure that mcp-1 is restarted when the configuration was changed
 914        {
 915            let _server_events = assert_server_events(
 916                &store,
 917                vec![
 918                    (server_1_id.clone(), ContextServerStatus::Stopped),
 919                    (server_1_id.clone(), ContextServerStatus::Starting),
 920                    (server_1_id.clone(), ContextServerStatus::Running),
 921                ],
 922                cx,
 923            );
 924            set_context_server_configuration(
 925                vec![(
 926                    server_1_id.0.clone(),
 927                    ContextServerSettings::Extension {
 928                        enabled: true,
 929                        settings: json!({
 930                            "somevalue": false
 931                        }),
 932                    },
 933                )],
 934                cx,
 935            );
 936
 937            cx.run_until_parked();
 938        }
 939
 940        // Ensure that mcp-1 is not restarted when the configuration was not changed
 941        {
 942            let _server_events = assert_server_events(&store, vec![], cx);
 943            set_context_server_configuration(
 944                vec![(
 945                    server_1_id.0.clone(),
 946                    ContextServerSettings::Extension {
 947                        enabled: true,
 948                        settings: json!({
 949                            "somevalue": false
 950                        }),
 951                    },
 952                )],
 953                cx,
 954            );
 955
 956            cx.run_until_parked();
 957        }
 958
 959        // Ensure that mcp-2 is started once it is added to the settings
 960        {
 961            let _server_events = assert_server_events(
 962                &store,
 963                vec![
 964                    (server_2_id.clone(), ContextServerStatus::Starting),
 965                    (server_2_id.clone(), ContextServerStatus::Running),
 966                ],
 967                cx,
 968            );
 969            set_context_server_configuration(
 970                vec![
 971                    (
 972                        server_1_id.0.clone(),
 973                        ContextServerSettings::Extension {
 974                            enabled: true,
 975                            settings: json!({
 976                                "somevalue": false
 977                            }),
 978                        },
 979                    ),
 980                    (
 981                        server_2_id.0.clone(),
 982                        ContextServerSettings::Custom {
 983                            enabled: true,
 984                            command: ContextServerCommand {
 985                                path: "somebinary".into(),
 986                                args: vec!["arg".to_string()],
 987                                env: None,
 988                                timeout: None,
 989                            },
 990                        },
 991                    ),
 992                ],
 993                cx,
 994            );
 995
 996            cx.run_until_parked();
 997        }
 998
 999        // Ensure that mcp-2 is restarted once the args have changed
1000        {
1001            let _server_events = assert_server_events(
1002                &store,
1003                vec![
1004                    (server_2_id.clone(), ContextServerStatus::Stopped),
1005                    (server_2_id.clone(), ContextServerStatus::Starting),
1006                    (server_2_id.clone(), ContextServerStatus::Running),
1007                ],
1008                cx,
1009            );
1010            set_context_server_configuration(
1011                vec![
1012                    (
1013                        server_1_id.0.clone(),
1014                        ContextServerSettings::Extension {
1015                            enabled: true,
1016                            settings: json!({
1017                                "somevalue": false
1018                            }),
1019                        },
1020                    ),
1021                    (
1022                        server_2_id.0.clone(),
1023                        ContextServerSettings::Custom {
1024                            enabled: true,
1025                            command: ContextServerCommand {
1026                                path: "somebinary".into(),
1027                                args: vec!["anotherArg".to_string()],
1028                                env: None,
1029                                timeout: None,
1030                            },
1031                        },
1032                    ),
1033                ],
1034                cx,
1035            );
1036
1037            cx.run_until_parked();
1038        }
1039
1040        // Ensure that mcp-2 is removed once it is removed from the settings
1041        {
1042            let _server_events = assert_server_events(
1043                &store,
1044                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1045                cx,
1046            );
1047            set_context_server_configuration(
1048                vec![(
1049                    server_1_id.0.clone(),
1050                    ContextServerSettings::Extension {
1051                        enabled: true,
1052                        settings: json!({
1053                            "somevalue": false
1054                        }),
1055                    },
1056                )],
1057                cx,
1058            );
1059
1060            cx.run_until_parked();
1061
1062            cx.update(|cx| {
1063                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1064            });
1065        }
1066
1067        // Ensure that nothing happens if the settings do not change
1068        {
1069            let _server_events = assert_server_events(&store, vec![], cx);
1070            set_context_server_configuration(
1071                vec![(
1072                    server_1_id.0.clone(),
1073                    ContextServerSettings::Extension {
1074                        enabled: true,
1075                        settings: json!({
1076                            "somevalue": false
1077                        }),
1078                    },
1079                )],
1080                cx,
1081            );
1082
1083            cx.run_until_parked();
1084
1085            cx.update(|cx| {
1086                assert_eq!(
1087                    store.read(cx).status_for_server(&server_1_id),
1088                    Some(ContextServerStatus::Running)
1089                );
1090                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1091            });
1092        }
1093    }
1094
1095    #[gpui::test]
1096    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1097        const SERVER_1_ID: &str = "mcp-1";
1098
1099        let server_1_id = ContextServerId(SERVER_1_ID.into());
1100
1101        let (_fs, project) = setup_context_server_test(
1102            cx,
1103            json!({"code.rs": ""}),
1104            vec![(
1105                SERVER_1_ID.into(),
1106                ContextServerSettings::Custom {
1107                    enabled: true,
1108                    command: ContextServerCommand {
1109                        path: "somebinary".into(),
1110                        args: vec!["arg".to_string()],
1111                        env: None,
1112                        timeout: None,
1113                    },
1114                },
1115            )],
1116        )
1117        .await;
1118
1119        let executor = cx.executor();
1120        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1121        let store = cx.new(|cx| {
1122            ContextServerStore::test_maintain_server_loop(
1123                Box::new(move |id, _| {
1124                    Arc::new(ContextServer::new(
1125                        id.clone(),
1126                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1127                    ))
1128                }),
1129                registry.clone(),
1130                project.read(cx).worktree_store(),
1131                project.downgrade(),
1132                cx,
1133            )
1134        });
1135
1136        // Ensure that mcp-1 starts up
1137        {
1138            let _server_events = assert_server_events(
1139                &store,
1140                vec![
1141                    (server_1_id.clone(), ContextServerStatus::Starting),
1142                    (server_1_id.clone(), ContextServerStatus::Running),
1143                ],
1144                cx,
1145            );
1146            cx.run_until_parked();
1147        }
1148
1149        // Ensure that mcp-1 is stopped once it is disabled.
1150        {
1151            let _server_events = assert_server_events(
1152                &store,
1153                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1154                cx,
1155            );
1156            set_context_server_configuration(
1157                vec![(
1158                    server_1_id.0.clone(),
1159                    ContextServerSettings::Custom {
1160                        enabled: false,
1161                        command: ContextServerCommand {
1162                            path: "somebinary".into(),
1163                            args: vec!["arg".to_string()],
1164                            env: None,
1165                            timeout: None,
1166                        },
1167                    },
1168                )],
1169                cx,
1170            );
1171
1172            cx.run_until_parked();
1173        }
1174
1175        // Ensure that mcp-1 is started once it is enabled again.
1176        {
1177            let _server_events = assert_server_events(
1178                &store,
1179                vec![
1180                    (server_1_id.clone(), ContextServerStatus::Starting),
1181                    (server_1_id.clone(), ContextServerStatus::Running),
1182                ],
1183                cx,
1184            );
1185            set_context_server_configuration(
1186                vec![(
1187                    server_1_id.0.clone(),
1188                    ContextServerSettings::Custom {
1189                        enabled: true,
1190                        command: ContextServerCommand {
1191                            path: "somebinary".into(),
1192                            args: vec!["arg".to_string()],
1193                            timeout: None,
1194                            env: None,
1195                        },
1196                    },
1197                )],
1198                cx,
1199            );
1200
1201            cx.run_until_parked();
1202        }
1203    }
1204
1205    fn set_context_server_configuration(
1206        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
1207        cx: &mut TestAppContext,
1208    ) {
1209        cx.update(|cx| {
1210            SettingsStore::update_global(cx, |store, cx| {
1211                let mut settings = ProjectSettings::default();
1212                for (id, config) in context_servers {
1213                    settings.context_servers.insert(id, config);
1214                }
1215                store
1216                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
1217                    .unwrap();
1218            })
1219        });
1220    }
1221
    /// Guard returned by `assert_server_events`.
    ///
    /// Bundles the store subscription that validates each incoming event with the
    /// counters that the `Drop` impl compares to check the total number of events.
    struct ServerEvents {
        // Incremented by the subscription callback once per validated event.
        received_event_count: Rc<RefCell<usize>>,
        // Length of the expected event list; compared against the received count on drop.
        expected_event_count: usize,
        // Held so the subscription stays registered for as long as the guard lives.
        _subscription: Subscription,
    }
1227
1228    impl Drop for ServerEvents {
1229        fn drop(&mut self) {
1230            let actual_event_count = *self.received_event_count.borrow();
1231            assert_eq!(
1232                actual_event_count, self.expected_event_count,
1233                "
1234                Expected to receive {} context server store events, but received {} events",
1235                self.expected_event_count, actual_event_count
1236            );
1237        }
1238    }
1239
1240    fn dummy_server_settings() -> ContextServerSettings {
1241        ContextServerSettings::Custom {
1242            enabled: true,
1243            command: ContextServerCommand {
1244                path: "somebinary".into(),
1245                args: vec!["arg".to_string()],
1246                env: None,
1247                timeout: None,
1248            },
1249        }
1250    }
1251
1252    fn assert_server_events(
1253        store: &Entity<ContextServerStore>,
1254        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1255        cx: &mut TestAppContext,
1256    ) -> ServerEvents {
1257        cx.update(|cx| {
1258            let mut ix = 0;
1259            let received_event_count = Rc::new(RefCell::new(0));
1260            let expected_event_count = expected_events.len();
1261            let subscription = cx.subscribe(store, {
1262                let received_event_count = received_event_count.clone();
1263                move |_, event, _| match event {
1264                    Event::ServerStatusChanged {
1265                        server_id: actual_server_id,
1266                        status: actual_status,
1267                    } => {
1268                        let (expected_server_id, expected_status) = &expected_events[ix];
1269
1270                        assert_eq!(
1271                            actual_server_id, expected_server_id,
1272                            "Expected different server id at index {}",
1273                            ix
1274                        );
1275                        assert_eq!(
1276                            actual_status, expected_status,
1277                            "Expected different status at index {}",
1278                            ix
1279                        );
1280                        ix += 1;
1281                        *received_event_count.borrow_mut() += 1;
1282                    }
1283                }
1284            });
1285            ServerEvents {
1286                expected_event_count,
1287                received_event_count,
1288                _subscription: subscription,
1289            }
1290        })
1291    }
1292
1293    async fn setup_context_server_test(
1294        cx: &mut TestAppContext,
1295        files: serde_json::Value,
1296        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1297    ) -> (Arc<FakeFs>, Entity<Project>) {
1298        cx.update(|cx| {
1299            let settings_store = SettingsStore::test(cx);
1300            cx.set_global(settings_store);
1301            Project::init_settings(cx);
1302            let mut settings = ProjectSettings::get_global(cx).clone();
1303            for (id, config) in context_server_configurations {
1304                settings.context_servers.insert(id, config);
1305            }
1306            ProjectSettings::override_global(settings, cx);
1307        });
1308
1309        let fs = FakeFs::new(cx.executor());
1310        fs.insert_tree(path!("/test"), files).await;
1311        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1312
1313        (fs, project)
1314    }
1315
    /// Test double for a context server descriptor whose command always launches `path`.
    struct FakeContextServerDescriptor {
        // Binary path that `ContextServerDescriptor::command` reports for this descriptor.
        path: PathBuf,
    }
1319
1320    impl FakeContextServerDescriptor {
1321        fn new(path: impl Into<PathBuf>) -> Self {
1322            Self { path: path.into() }
1323        }
1324    }
1325
1326    impl ContextServerDescriptor for FakeContextServerDescriptor {
1327        fn command(
1328            &self,
1329            _worktree_store: Entity<WorktreeStore>,
1330            _cx: &AsyncApp,
1331        ) -> Task<Result<ContextServerCommand>> {
1332            Task::ready(Ok(ContextServerCommand {
1333                path: self.path.clone(),
1334                args: vec!["arg1".to_string(), "arg2".to_string()],
1335                env: None,
1336                timeout: None,
1337            }))
1338        }
1339
1340        fn configuration(
1341            &self,
1342            _worktree_store: Entity<WorktreeStore>,
1343            _cx: &AsyncApp,
1344        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1345            Task::ready(Ok(None))
1346        }
1347    }
1348}