1pub mod extension;
   2pub mod registry;
   3
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::{ResultExt as _, rel_path::RelPath};
  14
  15use crate::{
  16    Project,
  17    project_settings::{ContextServerSettings, ProjectSettings},
  18    worktree_store::WorktreeStore,
  19};
  20
/// Initializes this module by delegating to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
  24
// Declares the actions that live in the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  32
/// Publicly visible lifecycle status of a context server.
///
/// This is the payload-free counterpart of the private `ContextServerState`:
/// it carries the variant (and the error message on failure) but not the
/// server handle or its configuration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server is in the process of starting.
    Starting,
    /// The server started successfully.
    Running,
    /// The server is known to the store but is not running.
    Stopped,
    /// The server failed to start; the payload is the error message.
    Error(Arc<str>),
}
  40
  41impl ContextServerStatus {
  42    fn from_state(state: &ContextServerState) -> Self {
  43        match state {
  44            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  45            ContextServerState::Running { .. } => ContextServerStatus::Running,
  46            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  47            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  48        }
  49    }
  50}
  51
/// Internal per-server bookkeeping kept by `ContextServerStore`.
///
/// Every variant holds the server handle and the configuration it was created
/// from, so both stay available regardless of the lifecycle phase.
enum ContextServerState {
    /// A start has been requested and has not completed yet.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Handle of the task driving the start; owned by this state so that
        // replacing or removing the state also drops the task handle.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped but is still tracked by the store.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Starting the server failed; `error` is the stringified failure.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
  72
  73impl ContextServerState {
  74    pub fn server(&self) -> Arc<ContextServer> {
  75        match self {
  76            ContextServerState::Starting { server, .. } => server.clone(),
  77            ContextServerState::Running { server, .. } => server.clone(),
  78            ContextServerState::Stopped { server, .. } => server.clone(),
  79            ContextServerState::Error { server, .. } => server.clone(),
  80        }
  81    }
  82
  83    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  84        match self {
  85            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  86            ContextServerState::Running { configuration, .. } => configuration.clone(),
  87            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  88            ContextServerState::Error { configuration, .. } => configuration.clone(),
  89        }
  90    }
  91}
  92
/// Resolved configuration a context server is launched with.
///
/// Built from `ContextServerSettings` by
/// [`ContextServerConfiguration::from_settings`].
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A server whose command comes directly from the settings.
    Custom {
        command: ContextServerCommand,
    },
    /// A server whose command is produced by a descriptor looked up in the
    /// registry; `settings` is the JSON value taken from the settings entry.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
 103
 104impl ContextServerConfiguration {
 105    pub fn command(&self) -> &ContextServerCommand {
 106        match self {
 107            ContextServerConfiguration::Custom { command } => command,
 108            ContextServerConfiguration::Extension { command, .. } => command,
 109        }
 110    }
 111
 112    pub async fn from_settings(
 113        settings: ContextServerSettings,
 114        id: ContextServerId,
 115        registry: Entity<ContextServerDescriptorRegistry>,
 116        worktree_store: Entity<WorktreeStore>,
 117        cx: &AsyncApp,
 118    ) -> Option<Self> {
 119        match settings {
 120            ContextServerSettings::Custom {
 121                enabled: _,
 122                command,
 123            } => Some(ContextServerConfiguration::Custom { command }),
 124            ContextServerSettings::Extension {
 125                enabled: _,
 126                settings,
 127            } => {
 128                let descriptor = cx
 129                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 130                    .ok()
 131                    .flatten()?;
 132
 133                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 134
 135                Some(ContextServerConfiguration::Extension { command, settings })
 136            }
 137        }
 138    }
 139}
 140
/// Callback that builds a [`ContextServer`] from an id and its configuration.
///
/// When a store holds one of these it is used in place of the default
/// stdio-based construction in `create_context_server`.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 143
/// Tracks the context servers of a project: their settings, their current
/// lifecycle state, and the background task that reconciles the two.
pub struct ContextServerStore {
    // Settings per server name, as last resolved from `ProjectSettings`.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to pick the settings location and a fallback root directory.
    worktree_store: Entity<WorktreeStore>,
    // Weak handle to the owning project; consulted when creating servers.
    project: WeakEntity<Project>,
    // Source of descriptors for extension-provided servers.
    registry: Entity<ContextServerDescriptorRegistry>,
    // The in-flight reconciliation pass, if one is running.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for how servers are constructed.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a change arrives while a reconciliation pass is running, so
    // that another pass is scheduled once the current one finishes.
    needs_server_update: bool,
    // Keeps the registry/settings observers alive for the store's lifetime.
    _subscriptions: Vec<Subscription>,
}
 155
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// The recorded status of `server_id` changed to `status`. Also emitted
    /// with [`ContextServerStatus::Stopped`] when a server is removed from the
    /// store.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

// Lets observers subscribe to `Event`s emitted by the store.
impl EventEmitter<Event> for ContextServerStore {}
 164
 165impl ContextServerStore {
 166    pub fn new(
 167        worktree_store: Entity<WorktreeStore>,
 168        weak_project: WeakEntity<Project>,
 169        cx: &mut Context<Self>,
 170    ) -> Self {
 171        Self::new_internal(
 172            true,
 173            None,
 174            ContextServerDescriptorRegistry::default_global(cx),
 175            worktree_store,
 176            weak_project,
 177            cx,
 178        )
 179    }
 180
 181    /// Returns all configured context server ids, regardless of enabled state.
 182    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 183        self.context_server_settings
 184            .keys()
 185            .cloned()
 186            .map(ContextServerId)
 187            .collect()
 188    }
 189
    /// Test-only constructor: builds a store against the given `registry`
    /// without the maintenance loop and without a server factory, so servers
    /// only change state through explicit calls.
    #[cfg(any(test, feature = "test-support"))]
    pub fn test(
        registry: Entity<ContextServerDescriptorRegistry>,
        worktree_store: Entity<WorktreeStore>,
        weak_project: WeakEntity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
    }
 199
    /// Test-only constructor: builds a store with the maintenance loop enabled
    /// and with `context_server_factory` used to construct every server.
    #[cfg(any(test, feature = "test-support"))]
    pub fn test_maintain_server_loop(
        context_server_factory: ContextServerFactory,
        registry: Entity<ContextServerDescriptorRegistry>,
        worktree_store: Entity<WorktreeStore>,
        weak_project: WeakEntity<Project>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self::new_internal(
            true,
            Some(context_server_factory),
            registry,
            worktree_store,
            weak_project,
            cx,
        )
    }
 217
 218    fn new_internal(
 219        maintain_server_loop: bool,
 220        context_server_factory: Option<ContextServerFactory>,
 221        registry: Entity<ContextServerDescriptorRegistry>,
 222        worktree_store: Entity<WorktreeStore>,
 223        weak_project: WeakEntity<Project>,
 224        cx: &mut Context<Self>,
 225    ) -> Self {
 226        let subscriptions = if maintain_server_loop {
 227            vec![
 228                cx.observe(®istry, |this, _registry, cx| {
 229                    this.available_context_servers_changed(cx);
 230                }),
 231                cx.observe_global::<SettingsStore>(|this, cx| {
 232                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 233                    if &this.context_server_settings == settings {
 234                        return;
 235                    }
 236                    this.context_server_settings = settings.clone();
 237                    this.available_context_servers_changed(cx);
 238                }),
 239            ]
 240        } else {
 241            Vec::new()
 242        };
 243
 244        let mut this = Self {
 245            _subscriptions: subscriptions,
 246            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 247                .clone(),
 248            worktree_store,
 249            project: weak_project,
 250            registry,
 251            needs_server_update: false,
 252            servers: HashMap::default(),
 253            update_servers_task: None,
 254            context_server_factory,
 255        };
 256        if maintain_server_loop {
 257            this.available_context_servers_changed(cx);
 258        }
 259        this
 260    }
 261
 262    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 263        self.servers.get(id).map(|state| state.server())
 264    }
 265
 266    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 267        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 268            Some(server.clone())
 269        } else {
 270            None
 271        }
 272    }
 273
 274    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 275        self.servers.get(id).map(ContextServerStatus::from_state)
 276    }
 277
 278    pub fn configuration_for_server(
 279        &self,
 280        id: &ContextServerId,
 281    ) -> Option<Arc<ContextServerConfiguration>> {
 282        self.servers.get(id).map(|state| state.configuration())
 283    }
 284
 285    pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
 286        self.servers
 287            .keys()
 288            .cloned()
 289            .chain(
 290                self.registry
 291                    .read(cx)
 292                    .context_server_descriptors()
 293                    .into_iter()
 294                    .map(|(id, _)| ContextServerId(id)),
 295            )
 296            .collect()
 297    }
 298
 299    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 300        self.servers
 301            .values()
 302            .filter_map(|state| {
 303                if let ContextServerState::Running { server, .. } = state {
 304                    Some(server.clone())
 305                } else {
 306                    None
 307                }
 308            })
 309            .collect()
 310    }
 311
 312    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 313        cx.spawn(async move |this, cx| {
 314            let this = this.upgrade().context("Context server store dropped")?;
 315            let settings = this
 316                .update(cx, |this, _| {
 317                    this.context_server_settings.get(&server.id().0).cloned()
 318                })
 319                .ok()
 320                .flatten()
 321                .context("Failed to get context server settings")?;
 322
 323            if !settings.enabled() {
 324                return Ok(());
 325            }
 326
 327            let (registry, worktree_store) = this.update(cx, |this, _| {
 328                (this.registry.clone(), this.worktree_store.clone())
 329            })?;
 330            let configuration = ContextServerConfiguration::from_settings(
 331                settings,
 332                server.id(),
 333                registry,
 334                worktree_store,
 335                cx,
 336            )
 337            .await
 338            .context("Failed to create context server configuration")?;
 339
 340            this.update(cx, |this, cx| {
 341                this.run_server(server, Arc::new(configuration), cx)
 342            })
 343        })
 344        .detach_and_log_err(cx);
 345    }
 346
 347    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 348        if matches!(
 349            self.servers.get(id),
 350            Some(ContextServerState::Stopped { .. })
 351        ) {
 352            return Ok(());
 353        }
 354
 355        let state = self
 356            .servers
 357            .remove(id)
 358            .context("Context server not found")?;
 359
 360        let server = state.server();
 361        let configuration = state.configuration();
 362        let mut result = Ok(());
 363        if let ContextServerState::Running { server, .. } = &state {
 364            result = server.stop();
 365        }
 366        drop(state);
 367
 368        self.update_server_state(
 369            id.clone(),
 370            ContextServerState::Stopped {
 371                configuration,
 372                server,
 373            },
 374            cx,
 375        );
 376
 377        result
 378    }
 379
 380    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 381        if let Some(state) = self.servers.get(id) {
 382            let configuration = state.configuration();
 383
 384            self.stop_server(&state.server().id(), cx)?;
 385            let new_server = self.create_context_server(id.clone(), configuration.clone(), cx);
 386            self.run_server(new_server, configuration, cx);
 387        }
 388        Ok(())
 389    }
 390
 391    fn run_server(
 392        &mut self,
 393        server: Arc<ContextServer>,
 394        configuration: Arc<ContextServerConfiguration>,
 395        cx: &mut Context<Self>,
 396    ) {
 397        let id = server.id();
 398        if matches!(
 399            self.servers.get(&id),
 400            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 401        ) {
 402            self.stop_server(&id, cx).log_err();
 403        }
 404
 405        let task = cx.spawn({
 406            let id = server.id();
 407            let server = server.clone();
 408            let configuration = configuration.clone();
 409            async move |this, cx| {
 410                match server.clone().start(cx).await {
 411                    Ok(_) => {
 412                        debug_assert!(server.client().is_some());
 413
 414                        this.update(cx, |this, cx| {
 415                            this.update_server_state(
 416                                id.clone(),
 417                                ContextServerState::Running {
 418                                    server,
 419                                    configuration,
 420                                },
 421                                cx,
 422                            )
 423                        })
 424                        .log_err()
 425                    }
 426                    Err(err) => {
 427                        log::error!("{} context server failed to start: {}", id, err);
 428                        this.update(cx, |this, cx| {
 429                            this.update_server_state(
 430                                id.clone(),
 431                                ContextServerState::Error {
 432                                    configuration,
 433                                    server,
 434                                    error: err.to_string().into(),
 435                                },
 436                                cx,
 437                            )
 438                        })
 439                        .log_err()
 440                    }
 441                };
 442            }
 443        });
 444
 445        self.update_server_state(
 446            id.clone(),
 447            ContextServerState::Starting {
 448                configuration,
 449                _task: task,
 450                server,
 451            },
 452            cx,
 453        );
 454    }
 455
 456    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 457        let state = self
 458            .servers
 459            .remove(id)
 460            .context("Context server not found")?;
 461        drop(state);
 462        cx.emit(Event::ServerStatusChanged {
 463            server_id: id.clone(),
 464            status: ContextServerStatus::Stopped,
 465        });
 466        Ok(())
 467    }
 468
 469    fn create_context_server(
 470        &self,
 471        id: ContextServerId,
 472        configuration: Arc<ContextServerConfiguration>,
 473        cx: &mut Context<Self>,
 474    ) -> Arc<ContextServer> {
 475        let project = self.project.upgrade();
 476        let mut root_path = None;
 477        if let Some(project) = project {
 478            let project = project.read(cx);
 479            if project.is_local() {
 480                if let Some(path) = project.active_project_directory(cx) {
 481                    root_path = Some(path);
 482                } else {
 483                    for worktree in self.worktree_store.read(cx).visible_worktrees(cx) {
 484                        if let Some(path) = worktree.read(cx).root_dir() {
 485                            root_path = Some(path);
 486                            break;
 487                        }
 488                    }
 489                }
 490            }
 491        };
 492
 493        if let Some(factory) = self.context_server_factory.as_ref() {
 494            factory(id, configuration)
 495        } else {
 496            Arc::new(ContextServer::stdio(
 497                id,
 498                configuration.command().clone(),
 499                root_path,
 500            ))
 501        }
 502    }
 503
 504    fn resolve_context_server_settings<'a>(
 505        worktree_store: &'a Entity<WorktreeStore>,
 506        cx: &'a App,
 507    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 508        let location = worktree_store
 509            .read(cx)
 510            .visible_worktrees(cx)
 511            .next()
 512            .map(|worktree| settings::SettingsLocation {
 513                worktree_id: worktree.read(cx).id(),
 514                path: RelPath::empty(),
 515            });
 516        &ProjectSettings::get(location, cx).context_servers
 517    }
 518
 519    fn update_server_state(
 520        &mut self,
 521        id: ContextServerId,
 522        state: ContextServerState,
 523        cx: &mut Context<Self>,
 524    ) {
 525        let status = ContextServerStatus::from_state(&state);
 526        self.servers.insert(id.clone(), state);
 527        cx.emit(Event::ServerStatusChanged {
 528            server_id: id,
 529            status,
 530        });
 531    }
 532
 533    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 534        if self.update_servers_task.is_some() {
 535            self.needs_server_update = true;
 536        } else {
 537            self.needs_server_update = false;
 538            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 539                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 540                    log::error!("Error maintaining context servers: {}", err);
 541                }
 542
 543                this.update(cx, |this, cx| {
 544                    this.update_servers_task.take();
 545                    if this.needs_server_update {
 546                        this.available_context_servers_changed(cx);
 547                    }
 548                })?;
 549
 550                Ok(())
 551            }));
 552        }
 553    }
 554
    /// Runs one reconciliation pass: brings the tracked servers in line with
    /// the current settings and the descriptors in the registry.
    ///
    /// The pass (1) snapshots settings and adds a default extension entry for
    /// every registered descriptor that has no settings entry, (2) resolves a
    /// configuration for every enabled entry, (3) decides which tracked
    /// servers to stop, remove or (re)start, and (4) applies those decisions.
    ///
    /// # Errors
    ///
    /// Fails if the store or app context is gone, or if stopping or removing a
    /// server fails; in the latter case the remaining actions of the pass are
    /// not applied.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        // Snapshot everything needed from the store so no borrow is held
        // across the awaits below.
        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
            (
                this.context_server_settings.clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Registered descriptors without an explicit settings entry get the
        // default extension settings; existing entries are left untouched.
        for (id, _) in
            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
        {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::default_extension());
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Resolve all enabled entries concurrently; entries whose
        // configuration cannot be resolved (`None`) are dropped here.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(|config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        this.update(cx, |this, cx| {
            for server_id in this.servers.keys() {
                // All servers that are not in `configured_servers` should no longer run.
                // Servers that are merely disabled are stopped but stay tracked; all others
                // are removed from the store. The latter can happen if the user removed a
                // server from the context server settings.
                if !configured_servers.contains_key(server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                // (Re)start when the server is new, its configuration changed,
                // or it is currently stopped.
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    let server = this.create_context_server(id.clone(), config.clone(), cx);
                    servers_to_start.push((server, config));
                    // An already tracked server is stopped before its replacement starts.
                    if this.servers.contains_key(&id) {
                        servers_to_stop.insert(id);
                    }
                }
            }
        })?;

        // Apply the decisions: stops first, then removals, then starts.
        this.update(cx, |this, cx| {
            for id in servers_to_stop {
                this.stop_server(&id, cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, cx)?;
            }
            for (server, config) in servers_to_start {
                this.run_server(server, config, cx);
            }
            anyhow::Ok(())
        })?
    }
 638}
 639
 640#[cfg(test)]
 641mod tests {
 642    use super::*;
 643    use crate::{
 644        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 645        project_settings::ProjectSettings,
 646    };
 647    use context_server::test::create_fake_transport;
 648    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 649    use serde_json::json;
 650    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 651    use util::path;
 652
    /// Checks `status_for_server` while two servers are started one after the
    /// other and the second one is then stopped.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        // `ContextServerStore::test` disables the maintenance loop, so servers
        // only change state through the explicit calls below.
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        // Only the first server has been started so far.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        // Both servers are running now.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        // Stopping the second server leaves the first one untouched.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }
 732
    /// Drives the same start/start/stop scenario as the status test, but
    /// against an expected sequence of status events rather than polled
    /// statuses.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // NOTE(review): `assert_server_events` is defined outside this view;
        // presumably it checks that the store emits exactly these events in
        // this order and the returned guard must stay alive until the end of
        // the test — confirm against the helper.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }
 794
    /// Starts two distinct server objects that share one id back to back,
    /// without letting the first start finish, and checks that the id ends up
    /// running.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        // NOTE(review): `assert_server_events` is defined outside this view;
        // presumably it verifies this exact event sequence — confirm.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        // Both starts are issued before the executor runs, so the second one
        // is requested while the first has not completed.
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }
 855
 856    #[gpui::test]
 857    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 858        const SERVER_1_ID: &str = "mcp-1";
 859        const SERVER_2_ID: &str = "mcp-2";
 860
 861        let server_1_id = ContextServerId(SERVER_1_ID.into());
 862        let server_2_id = ContextServerId(SERVER_2_ID.into());
 863
 864        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 865
 866        let (_fs, project) = setup_context_server_test(
 867            cx,
 868            json!({"code.rs": ""}),
 869            vec![(
 870                SERVER_1_ID.into(),
 871                ContextServerSettings::Extension {
 872                    enabled: true,
 873                    settings: json!({
 874                        "somevalue": true
 875                    }),
 876                },
 877            )],
 878        )
 879        .await;
 880
 881        let executor = cx.executor();
 882        let registry = cx.new(|cx| {
 883            let mut registry = ContextServerDescriptorRegistry::new();
 884            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 885            registry
 886        });
 887        let store = cx.new(|cx| {
 888            ContextServerStore::test_maintain_server_loop(
 889                Box::new(move |id, _| {
 890                    Arc::new(ContextServer::new(
 891                        id.clone(),
 892                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 893                    ))
 894                }),
 895                registry.clone(),
 896                project.read(cx).worktree_store(),
 897                project.downgrade(),
 898                cx,
 899            )
 900        });
 901
 902        // Ensure that mcp-1 starts up
 903        {
 904            let _server_events = assert_server_events(
 905                &store,
 906                vec![
 907                    (server_1_id.clone(), ContextServerStatus::Starting),
 908                    (server_1_id.clone(), ContextServerStatus::Running),
 909                ],
 910                cx,
 911            );
 912            cx.run_until_parked();
 913        }
 914
 915        // Ensure that mcp-1 is restarted when the configuration was changed
 916        {
 917            let _server_events = assert_server_events(
 918                &store,
 919                vec![
 920                    (server_1_id.clone(), ContextServerStatus::Stopped),
 921                    (server_1_id.clone(), ContextServerStatus::Starting),
 922                    (server_1_id.clone(), ContextServerStatus::Running),
 923                ],
 924                cx,
 925            );
 926            set_context_server_configuration(
 927                vec![(
 928                    server_1_id.0.clone(),
 929                    settings::ContextServerSettingsContent::Extension {
 930                        enabled: true,
 931                        settings: json!({
 932                            "somevalue": false
 933                        }),
 934                    },
 935                )],
 936                cx,
 937            );
 938
 939            cx.run_until_parked();
 940        }
 941
 942        // Ensure that mcp-1 is not restarted when the configuration was not changed
 943        {
 944            let _server_events = assert_server_events(&store, vec![], cx);
 945            set_context_server_configuration(
 946                vec![(
 947                    server_1_id.0.clone(),
 948                    settings::ContextServerSettingsContent::Extension {
 949                        enabled: true,
 950                        settings: json!({
 951                            "somevalue": false
 952                        }),
 953                    },
 954                )],
 955                cx,
 956            );
 957
 958            cx.run_until_parked();
 959        }
 960
 961        // Ensure that mcp-2 is started once it is added to the settings
 962        {
 963            let _server_events = assert_server_events(
 964                &store,
 965                vec![
 966                    (server_2_id.clone(), ContextServerStatus::Starting),
 967                    (server_2_id.clone(), ContextServerStatus::Running),
 968                ],
 969                cx,
 970            );
 971            set_context_server_configuration(
 972                vec![
 973                    (
 974                        server_1_id.0.clone(),
 975                        settings::ContextServerSettingsContent::Extension {
 976                            enabled: true,
 977                            settings: json!({
 978                                "somevalue": false
 979                            }),
 980                        },
 981                    ),
 982                    (
 983                        server_2_id.0.clone(),
 984                        settings::ContextServerSettingsContent::Custom {
 985                            enabled: true,
 986                            command: ContextServerCommand {
 987                                path: "somebinary".into(),
 988                                args: vec!["arg".to_string()],
 989                                env: None,
 990                                timeout: None,
 991                            },
 992                        },
 993                    ),
 994                ],
 995                cx,
 996            );
 997
 998            cx.run_until_parked();
 999        }
1000
1001        // Ensure that mcp-2 is restarted once the args have changed
1002        {
1003            let _server_events = assert_server_events(
1004                &store,
1005                vec![
1006                    (server_2_id.clone(), ContextServerStatus::Stopped),
1007                    (server_2_id.clone(), ContextServerStatus::Starting),
1008                    (server_2_id.clone(), ContextServerStatus::Running),
1009                ],
1010                cx,
1011            );
1012            set_context_server_configuration(
1013                vec![
1014                    (
1015                        server_1_id.0.clone(),
1016                        settings::ContextServerSettingsContent::Extension {
1017                            enabled: true,
1018                            settings: json!({
1019                                "somevalue": false
1020                            }),
1021                        },
1022                    ),
1023                    (
1024                        server_2_id.0.clone(),
1025                        settings::ContextServerSettingsContent::Custom {
1026                            enabled: true,
1027                            command: ContextServerCommand {
1028                                path: "somebinary".into(),
1029                                args: vec!["anotherArg".to_string()],
1030                                env: None,
1031                                timeout: None,
1032                            },
1033                        },
1034                    ),
1035                ],
1036                cx,
1037            );
1038
1039            cx.run_until_parked();
1040        }
1041
1042        // Ensure that mcp-2 is removed once it is removed from the settings
1043        {
1044            let _server_events = assert_server_events(
1045                &store,
1046                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1047                cx,
1048            );
1049            set_context_server_configuration(
1050                vec![(
1051                    server_1_id.0.clone(),
1052                    settings::ContextServerSettingsContent::Extension {
1053                        enabled: true,
1054                        settings: json!({
1055                            "somevalue": false
1056                        }),
1057                    },
1058                )],
1059                cx,
1060            );
1061
1062            cx.run_until_parked();
1063
1064            cx.update(|cx| {
1065                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1066            });
1067        }
1068
1069        // Ensure that nothing happens if the settings do not change
1070        {
1071            let _server_events = assert_server_events(&store, vec![], cx);
1072            set_context_server_configuration(
1073                vec![(
1074                    server_1_id.0.clone(),
1075                    settings::ContextServerSettingsContent::Extension {
1076                        enabled: true,
1077                        settings: json!({
1078                            "somevalue": false
1079                        }),
1080                    },
1081                )],
1082                cx,
1083            );
1084
1085            cx.run_until_parked();
1086
1087            cx.update(|cx| {
1088                assert_eq!(
1089                    store.read(cx).status_for_server(&server_1_id),
1090                    Some(ContextServerStatus::Running)
1091                );
1092                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1093            });
1094        }
1095    }
1096
1097    #[gpui::test]
1098    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1099        const SERVER_1_ID: &str = "mcp-1";
1100
1101        let server_1_id = ContextServerId(SERVER_1_ID.into());
1102
1103        let (_fs, project) = setup_context_server_test(
1104            cx,
1105            json!({"code.rs": ""}),
1106            vec![(
1107                SERVER_1_ID.into(),
1108                ContextServerSettings::Custom {
1109                    enabled: true,
1110                    command: ContextServerCommand {
1111                        path: "somebinary".into(),
1112                        args: vec!["arg".to_string()],
1113                        env: None,
1114                        timeout: None,
1115                    },
1116                },
1117            )],
1118        )
1119        .await;
1120
1121        let executor = cx.executor();
1122        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1123        let store = cx.new(|cx| {
1124            ContextServerStore::test_maintain_server_loop(
1125                Box::new(move |id, _| {
1126                    Arc::new(ContextServer::new(
1127                        id.clone(),
1128                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1129                    ))
1130                }),
1131                registry.clone(),
1132                project.read(cx).worktree_store(),
1133                project.downgrade(),
1134                cx,
1135            )
1136        });
1137
1138        // Ensure that mcp-1 starts up
1139        {
1140            let _server_events = assert_server_events(
1141                &store,
1142                vec![
1143                    (server_1_id.clone(), ContextServerStatus::Starting),
1144                    (server_1_id.clone(), ContextServerStatus::Running),
1145                ],
1146                cx,
1147            );
1148            cx.run_until_parked();
1149        }
1150
1151        // Ensure that mcp-1 is stopped once it is disabled.
1152        {
1153            let _server_events = assert_server_events(
1154                &store,
1155                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1156                cx,
1157            );
1158            set_context_server_configuration(
1159                vec![(
1160                    server_1_id.0.clone(),
1161                    settings::ContextServerSettingsContent::Custom {
1162                        enabled: false,
1163                        command: ContextServerCommand {
1164                            path: "somebinary".into(),
1165                            args: vec!["arg".to_string()],
1166                            env: None,
1167                            timeout: None,
1168                        },
1169                    },
1170                )],
1171                cx,
1172            );
1173
1174            cx.run_until_parked();
1175        }
1176
1177        // Ensure that mcp-1 is started once it is enabled again.
1178        {
1179            let _server_events = assert_server_events(
1180                &store,
1181                vec![
1182                    (server_1_id.clone(), ContextServerStatus::Starting),
1183                    (server_1_id.clone(), ContextServerStatus::Running),
1184                ],
1185                cx,
1186            );
1187            set_context_server_configuration(
1188                vec![(
1189                    server_1_id.0.clone(),
1190                    settings::ContextServerSettingsContent::Custom {
1191                        enabled: true,
1192                        command: ContextServerCommand {
1193                            path: "somebinary".into(),
1194                            args: vec!["arg".to_string()],
1195                            timeout: None,
1196                            env: None,
1197                        },
1198                    },
1199                )],
1200                cx,
1201            );
1202
1203            cx.run_until_parked();
1204        }
1205    }
1206
1207    fn set_context_server_configuration(
1208        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1209        cx: &mut TestAppContext,
1210    ) {
1211        cx.update(|cx| {
1212            SettingsStore::update_global(cx, |store, cx| {
1213                store.update_user_settings(cx, |content| {
1214                    content.project.context_servers.clear();
1215                    for (id, config) in context_servers {
1216                        content.project.context_servers.insert(id, config);
1217                    }
1218                });
1219            })
1220        });
1221    }
1222
    /// Guard returned by `assert_server_events`.
    ///
    /// It records how many store events were observed and, when dropped,
    /// compares that number against the number of events the test expected.
    struct ServerEvents {
        // Number of events seen so far; shared with the subscription callback
        // created in `assert_server_events`, which increments it.
        received_event_count: Rc<RefCell<usize>>,
        // Length of the expected-events list given to `assert_server_events`.
        expected_event_count: usize,
        // Stored but never read; presumably keeps the event subscription
        // registered for as long as this guard is alive (TODO confirm).
        _subscription: Subscription,
    }
1228
1229    impl Drop for ServerEvents {
1230        fn drop(&mut self) {
1231            let actual_event_count = *self.received_event_count.borrow();
1232            assert_eq!(
1233                actual_event_count, self.expected_event_count,
1234                "
1235                Expected to receive {} context server store events, but received {} events",
1236                self.expected_event_count, actual_event_count
1237            );
1238        }
1239    }
1240
1241    fn dummy_server_settings() -> ContextServerSettings {
1242        ContextServerSettings::Custom {
1243            enabled: true,
1244            command: ContextServerCommand {
1245                path: "somebinary".into(),
1246                args: vec!["arg".to_string()],
1247                env: None,
1248                timeout: None,
1249            },
1250        }
1251    }
1252
1253    fn assert_server_events(
1254        store: &Entity<ContextServerStore>,
1255        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1256        cx: &mut TestAppContext,
1257    ) -> ServerEvents {
1258        cx.update(|cx| {
1259            let mut ix = 0;
1260            let received_event_count = Rc::new(RefCell::new(0));
1261            let expected_event_count = expected_events.len();
1262            let subscription = cx.subscribe(store, {
1263                let received_event_count = received_event_count.clone();
1264                move |_, event, _| match event {
1265                    Event::ServerStatusChanged {
1266                        server_id: actual_server_id,
1267                        status: actual_status,
1268                    } => {
1269                        let (expected_server_id, expected_status) = &expected_events[ix];
1270
1271                        assert_eq!(
1272                            actual_server_id, expected_server_id,
1273                            "Expected different server id at index {}",
1274                            ix
1275                        );
1276                        assert_eq!(
1277                            actual_status, expected_status,
1278                            "Expected different status at index {}",
1279                            ix
1280                        );
1281                        ix += 1;
1282                        *received_event_count.borrow_mut() += 1;
1283                    }
1284                }
1285            });
1286            ServerEvents {
1287                expected_event_count,
1288                received_event_count,
1289                _subscription: subscription,
1290            }
1291        })
1292    }
1293
1294    async fn setup_context_server_test(
1295        cx: &mut TestAppContext,
1296        files: serde_json::Value,
1297        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1298    ) -> (Arc<FakeFs>, Entity<Project>) {
1299        cx.update(|cx| {
1300            let settings_store = SettingsStore::test(cx);
1301            cx.set_global(settings_store);
1302            Project::init_settings(cx);
1303            let mut settings = ProjectSettings::get_global(cx).clone();
1304            for (id, config) in context_server_configurations {
1305                settings.context_servers.insert(id, config);
1306            }
1307            ProjectSettings::override_global(settings, cx);
1308        });
1309
1310        let fs = FakeFs::new(cx.executor());
1311        fs.insert_tree(path!("/test"), files).await;
1312        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1313
1314        (fs, project)
1315    }
1316
1317    struct FakeContextServerDescriptor {
1318        path: PathBuf,
1319    }
1320
1321    impl FakeContextServerDescriptor {
1322        fn new(path: impl Into<PathBuf>) -> Self {
1323            Self { path: path.into() }
1324        }
1325    }
1326
1327    impl ContextServerDescriptor for FakeContextServerDescriptor {
1328        fn command(
1329            &self,
1330            _worktree_store: Entity<WorktreeStore>,
1331            _cx: &AsyncApp,
1332        ) -> Task<Result<ContextServerCommand>> {
1333            Task::ready(Ok(ContextServerCommand {
1334                path: self.path.clone(),
1335                args: vec!["arg1".to_string(), "arg2".to_string()],
1336                env: None,
1337                timeout: None,
1338            }))
1339        }
1340
1341        fn configuration(
1342            &self,
1343            _worktree_store: Entity<WorktreeStore>,
1344            _cx: &AsyncApp,
1345        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1346            Task::ready(Ok(None))
1347        }
1348    }
1349}