context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::{ResultExt as _, rel_path::RelPath};
  14
  15use crate::{
  16    Project,
  17    project_settings::{ContextServerSettings, ProjectSettings},
  18    trusted_worktrees::wait_for_workspace_trust,
  19    worktree_store::WorktreeStore,
  20};
  21
/// Module initialization hook; delegates to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
  25
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  33
/// Externally visible lifecycle status of a context server.
///
/// Derived from the store's internal state by `ContextServerStatus::from_state` and
/// carried by [`Event::ServerStatusChanged`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server is being started.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// Starting the server failed; holds the error message.
    Error(Arc<str>),
}
  41
  42impl ContextServerStatus {
  43    fn from_state(state: &ContextServerState) -> Self {
  44        match state {
  45            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  46            ContextServerState::Running { .. } => ContextServerStatus::Running,
  47            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  48            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  49        }
  50    }
  51}
  52
/// Internal per-server state tracked by [`ContextServerStore`].
///
/// Every variant keeps the server handle and the configuration it was started with, so
/// `server()` and `configuration()` are available whatever the current state is.
enum ContextServerState {
    /// Startup is in progress. `_task` is the task created by `run_server` that awaits
    /// `ContextServer::start`; it is owned by (and dropped with) this state.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// `ContextServer::start` completed successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped through `stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// `ContextServer::start` failed; `error` holds the error's string form.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
  73
  74impl ContextServerState {
  75    pub fn server(&self) -> Arc<ContextServer> {
  76        match self {
  77            ContextServerState::Starting { server, .. } => server.clone(),
  78            ContextServerState::Running { server, .. } => server.clone(),
  79            ContextServerState::Stopped { server, .. } => server.clone(),
  80            ContextServerState::Error { server, .. } => server.clone(),
  81        }
  82    }
  83
  84    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  85        match self {
  86            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  87            ContextServerState::Running { configuration, .. } => configuration.clone(),
  88            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  89            ContextServerState::Error { configuration, .. } => configuration.clone(),
  90        }
  91    }
  92}
  93
/// Fully resolved configuration used to create a [`ContextServer`].
///
/// Configurations are compared with `PartialEq` during reconciliation to decide whether
/// an existing server has to be restarted.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A user-configured stdio server (built from `ContextServerSettings::Stdio`).
    Custom {
        command: ContextServerCommand,
    },
    /// An extension-provided server: `command` is produced by the registry descriptor,
    /// `settings` is the user's settings value for it.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
    /// A remote server at `url`; `headers` are passed to `ContextServer::http`.
    Http {
        url: url::Url,
        headers: HashMap<String, String>,
    },
}
 108
 109impl ContextServerConfiguration {
 110    pub fn command(&self) -> Option<&ContextServerCommand> {
 111        match self {
 112            ContextServerConfiguration::Custom { command } => Some(command),
 113            ContextServerConfiguration::Extension { command, .. } => Some(command),
 114            ContextServerConfiguration::Http { .. } => None,
 115        }
 116    }
 117
 118    pub async fn from_settings(
 119        settings: ContextServerSettings,
 120        id: ContextServerId,
 121        registry: Entity<ContextServerDescriptorRegistry>,
 122        worktree_store: Entity<WorktreeStore>,
 123        cx: &AsyncApp,
 124    ) -> Option<Self> {
 125        match settings {
 126            ContextServerSettings::Stdio {
 127                enabled: _,
 128                command,
 129            } => Some(ContextServerConfiguration::Custom { command }),
 130            ContextServerSettings::Extension {
 131                enabled: _,
 132                settings,
 133            } => {
 134                let descriptor = cx
 135                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 136                    .ok()
 137                    .flatten()?;
 138
 139                match descriptor.command(worktree_store, cx).await {
 140                    Ok(command) => {
 141                        Some(ContextServerConfiguration::Extension { command, settings })
 142                    }
 143                    Err(e) => {
 144                        log::error!(
 145                            "Failed to create context server configuration from settings: {e:#}"
 146                        );
 147                        None
 148                    }
 149                }
 150            }
 151            ContextServerSettings::Http {
 152                enabled: _,
 153                url,
 154                headers: auth,
 155            } => {
 156                let url = url::Url::parse(&url).log_err()?;
 157                Some(ContextServerConfiguration::Http { url, headers: auth })
 158            }
 159        }
 160    }
 161}
 162
/// Optional override for building servers.
///
/// When set, `create_context_server` calls it instead of constructing stdio/HTTP servers.
/// Only the test-support constructor `test_maintain_server_loop` supplies one.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 165
/// Tracks a project's context servers and their lifecycle state.
///
/// When the maintenance loop is enabled, the store observes the descriptor registry and
/// the global settings, and reconciles `servers` with them.
pub struct ContextServerStore {
    /// `context_servers` settings resolved for the first visible worktree; refreshed when
    /// the global settings change.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every server known to the store.
    servers: HashMap<ContextServerId, ContextServerState>,
    worktree_store: Entity<WorktreeStore>,
    project: WeakEntity<Project>,
    /// Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The in-flight reconciliation task, if any; at most one runs at a time.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional server-construction override (only set by test-support constructors).
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a reconciliation is requested while one is running, so that another
    /// pass starts once the current one finishes.
    needs_server_update: bool,
    /// Registry/settings observers held for the store's lifetime (presumably dropping
    /// them would unregister the observers — gpui convention, not shown here).
    _subscriptions: Vec<Subscription>,
}
 177
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// Emitted whenever a server's state is updated, and with `Stopped` when a server is
    /// removed from the store.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}
 184
// Allows the store to publish `Event`s through `cx.emit`.
impl EventEmitter<Event> for ContextServerStore {}
 186
 187impl ContextServerStore {
 188    pub fn new(
 189        worktree_store: Entity<WorktreeStore>,
 190        weak_project: WeakEntity<Project>,
 191        cx: &mut Context<Self>,
 192    ) -> Self {
 193        Self::new_internal(
 194            true,
 195            None,
 196            ContextServerDescriptorRegistry::default_global(cx),
 197            worktree_store,
 198            weak_project,
 199            cx,
 200        )
 201    }
 202
 203    /// Returns all configured context server ids, excluding the ones that are disabled
 204    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 205        self.context_server_settings
 206            .iter()
 207            .filter(|(_, settings)| settings.enabled())
 208            .map(|(id, _)| ContextServerId(id.clone()))
 209            .collect()
 210    }
 211
 212    #[cfg(any(test, feature = "test-support"))]
 213    pub fn test(
 214        registry: Entity<ContextServerDescriptorRegistry>,
 215        worktree_store: Entity<WorktreeStore>,
 216        weak_project: WeakEntity<Project>,
 217        cx: &mut Context<Self>,
 218    ) -> Self {
 219        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
 220    }
 221
 222    #[cfg(any(test, feature = "test-support"))]
 223    pub fn test_maintain_server_loop(
 224        context_server_factory: Option<ContextServerFactory>,
 225        registry: Entity<ContextServerDescriptorRegistry>,
 226        worktree_store: Entity<WorktreeStore>,
 227        weak_project: WeakEntity<Project>,
 228        cx: &mut Context<Self>,
 229    ) -> Self {
 230        Self::new_internal(
 231            true,
 232            context_server_factory,
 233            registry,
 234            worktree_store,
 235            weak_project,
 236            cx,
 237        )
 238    }
 239
 240    fn new_internal(
 241        maintain_server_loop: bool,
 242        context_server_factory: Option<ContextServerFactory>,
 243        registry: Entity<ContextServerDescriptorRegistry>,
 244        worktree_store: Entity<WorktreeStore>,
 245        weak_project: WeakEntity<Project>,
 246        cx: &mut Context<Self>,
 247    ) -> Self {
 248        let subscriptions = if maintain_server_loop {
 249            vec![
 250                cx.observe(&registry, |this, _registry, cx| {
 251                    this.available_context_servers_changed(cx);
 252                }),
 253                cx.observe_global::<SettingsStore>(|this, cx| {
 254                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 255                    if &this.context_server_settings == settings {
 256                        return;
 257                    }
 258                    this.context_server_settings = settings.clone();
 259                    this.available_context_servers_changed(cx);
 260                }),
 261            ]
 262        } else {
 263            Vec::new()
 264        };
 265
 266        let mut this = Self {
 267            _subscriptions: subscriptions,
 268            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 269                .clone(),
 270            worktree_store,
 271            project: weak_project,
 272            registry,
 273            needs_server_update: false,
 274            servers: HashMap::default(),
 275            update_servers_task: None,
 276            context_server_factory,
 277        };
 278        if maintain_server_loop {
 279            this.available_context_servers_changed(cx);
 280        }
 281        this
 282    }
 283
 284    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 285        self.servers.get(id).map(|state| state.server())
 286    }
 287
 288    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 289        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 290            Some(server.clone())
 291        } else {
 292            None
 293        }
 294    }
 295
 296    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 297        self.servers.get(id).map(ContextServerStatus::from_state)
 298    }
 299
 300    pub fn configuration_for_server(
 301        &self,
 302        id: &ContextServerId,
 303    ) -> Option<Arc<ContextServerConfiguration>> {
 304        self.servers.get(id).map(|state| state.configuration())
 305    }
 306
 307    pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
 308        self.servers
 309            .keys()
 310            .cloned()
 311            .chain(
 312                self.registry
 313                    .read(cx)
 314                    .context_server_descriptors()
 315                    .into_iter()
 316                    .map(|(id, _)| ContextServerId(id)),
 317            )
 318            .collect()
 319    }
 320
 321    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 322        self.servers
 323            .values()
 324            .filter_map(|state| {
 325                if let ContextServerState::Running { server, .. } = state {
 326                    Some(server.clone())
 327                } else {
 328                    None
 329                }
 330            })
 331            .collect()
 332    }
 333
 334    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 335        cx.spawn(async move |this, cx| {
 336            let wait_task = this.update(cx, |context_server_store, cx| {
 337                context_server_store.project.update(cx, |project, cx| {
 338                    let remote_host = project.remote_connection_options(cx);
 339                    wait_for_workspace_trust(remote_host, "context servers", cx)
 340                })
 341            })??;
 342            if let Some(wait_task) = wait_task {
 343                wait_task.await;
 344            }
 345            let this = this.upgrade().context("Context server store dropped")?;
 346            let settings = this
 347                .update(cx, |this, _| {
 348                    this.context_server_settings.get(&server.id().0).cloned()
 349                })
 350                .ok()
 351                .flatten()
 352                .context("Failed to get context server settings")?;
 353
 354            if !settings.enabled() {
 355                return Ok(());
 356            }
 357
 358            let (registry, worktree_store) = this.update(cx, |this, _| {
 359                (this.registry.clone(), this.worktree_store.clone())
 360            })?;
 361            let configuration = ContextServerConfiguration::from_settings(
 362                settings,
 363                server.id(),
 364                registry,
 365                worktree_store,
 366                cx,
 367            )
 368            .await
 369            .context("Failed to create context server configuration")?;
 370
 371            this.update(cx, |this, cx| {
 372                this.run_server(server, Arc::new(configuration), cx)
 373            })
 374        })
 375        .detach_and_log_err(cx);
 376    }
 377
 378    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 379        if matches!(
 380            self.servers.get(id),
 381            Some(ContextServerState::Stopped { .. })
 382        ) {
 383            return Ok(());
 384        }
 385
 386        let state = self
 387            .servers
 388            .remove(id)
 389            .context("Context server not found")?;
 390
 391        let server = state.server();
 392        let configuration = state.configuration();
 393        let mut result = Ok(());
 394        if let ContextServerState::Running { server, .. } = &state {
 395            result = server.stop();
 396        }
 397        drop(state);
 398
 399        self.update_server_state(
 400            id.clone(),
 401            ContextServerState::Stopped {
 402                configuration,
 403                server,
 404            },
 405            cx,
 406        );
 407
 408        result
 409    }
 410
    /// Starts `server` with `configuration`, replacing any instance with the same id.
    ///
    /// If a server with this id is currently `Starting` or `Running`, it is stopped first
    /// via `stop_server`, which emits `Stopped`. The server is then recorded as `Starting`
    /// (emitting `Starting`), and that state owns the startup task. When the task
    /// completes, the state becomes `Running` on success or `Error` (with the error
    /// message) on failure.
    fn run_server(
        &mut self,
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        cx: &mut Context<Self>,
    ) {
        let id = server.id();
        // Replace any in-flight or live instance; a failure to stop it is only logged.
        if matches!(
            self.servers.get(&id),
            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
        ) {
            self.stop_server(&id, cx).log_err();
        }
        let task = cx.spawn({
            let id = server.id();
            let server = server.clone();
            let configuration = configuration.clone();

            async move |this, cx| {
                match server.clone().start(cx).await {
                    Ok(_) => {
                        // A successfully started server is expected to expose a client.
                        debug_assert!(server.client().is_some());

                        this.update(cx, |this, cx| {
                            this.update_server_state(
                                id.clone(),
                                ContextServerState::Running {
                                    server,
                                    configuration,
                                },
                                cx,
                            )
                        })
                        .log_err()
                    }
                    Err(err) => {
                        log::error!("{} context server failed to start: {}", id, err);
                        this.update(cx, |this, cx| {
                            this.update_server_state(
                                id.clone(),
                                ContextServerState::Error {
                                    configuration,
                                    server,
                                    error: err.to_string().into(),
                                },
                                cx,
                            )
                        })
                        .log_err()
                    }
                };
            }
        });

        // The startup task is stored in the `Starting` state, so replacing or removing
        // that state also drops the task.
        self.update_server_state(
            id.clone(),
            ContextServerState::Starting {
                configuration,
                _task: task,
                server,
            },
            cx,
        );
    }
 475
 476    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 477        let state = self
 478            .servers
 479            .remove(id)
 480            .context("Context server not found")?;
 481        drop(state);
 482        cx.emit(Event::ServerStatusChanged {
 483            server_id: id.clone(),
 484            status: ContextServerStatus::Stopped,
 485        });
 486        Ok(())
 487    }
 488
 489    fn create_context_server(
 490        &self,
 491        id: ContextServerId,
 492        configuration: Arc<ContextServerConfiguration>,
 493        cx: &mut Context<Self>,
 494    ) -> Result<Arc<ContextServer>> {
 495        if let Some(factory) = self.context_server_factory.as_ref() {
 496            return Ok(factory(id, configuration));
 497        }
 498
 499        match configuration.as_ref() {
 500            ContextServerConfiguration::Http { url, headers } => Ok(Arc::new(ContextServer::http(
 501                id,
 502                url,
 503                headers.clone(),
 504                cx.http_client(),
 505                cx.background_executor().clone(),
 506            )?)),
 507            _ => {
 508                let root_path = self
 509                    .project
 510                    .read_with(cx, |project, cx| project.active_project_directory(cx))
 511                    .ok()
 512                    .flatten()
 513                    .or_else(|| {
 514                        self.worktree_store.read_with(cx, |store, cx| {
 515                            store.visible_worktrees(cx).fold(None, |acc, item| {
 516                                if acc.is_none() {
 517                                    item.read(cx).root_dir()
 518                                } else {
 519                                    acc
 520                                }
 521                            })
 522                        })
 523                    });
 524                Ok(Arc::new(ContextServer::stdio(
 525                    id,
 526                    configuration.command().unwrap().clone(),
 527                    root_path,
 528                )))
 529            }
 530        }
 531    }
 532
 533    fn resolve_context_server_settings<'a>(
 534        worktree_store: &'a Entity<WorktreeStore>,
 535        cx: &'a App,
 536    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 537        let location = worktree_store
 538            .read(cx)
 539            .visible_worktrees(cx)
 540            .next()
 541            .map(|worktree| settings::SettingsLocation {
 542                worktree_id: worktree.read(cx).id(),
 543                path: RelPath::empty(),
 544            });
 545        &ProjectSettings::get(location, cx).context_servers
 546    }
 547
 548    fn update_server_state(
 549        &mut self,
 550        id: ContextServerId,
 551        state: ContextServerState,
 552        cx: &mut Context<Self>,
 553    ) {
 554        let status = ContextServerStatus::from_state(&state);
 555        self.servers.insert(id.clone(), state);
 556        cx.emit(Event::ServerStatusChanged {
 557            server_id: id,
 558            status,
 559        });
 560    }
 561
 562    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 563        if self.update_servers_task.is_some() {
 564            self.needs_server_update = true;
 565        } else {
 566            self.needs_server_update = false;
 567            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 568                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 569                    log::error!("Error maintaining context servers: {}", err);
 570                }
 571
 572                this.update(cx, |this, cx| {
 573                    this.update_servers_task.take();
 574                    if this.needs_server_update {
 575                        this.available_context_servers_changed(cx);
 576                    }
 577                })?;
 578
 579                Ok(())
 580            }));
 581        }
 582    }
 583
 584    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 585        let wait_task = this.update(cx, |context_server_store, cx| {
 586            context_server_store.project.update(cx, |project, cx| {
 587                let remote_host = project.remote_connection_options(cx);
 588                wait_for_workspace_trust(remote_host, "context servers", cx)
 589            })
 590        })??;
 591        if let Some(wait_task) = wait_task {
 592            wait_task.await;
 593        }
 594        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 595            (
 596                this.context_server_settings.clone(),
 597                this.registry.clone(),
 598                this.worktree_store.clone(),
 599            )
 600        })?;
 601
 602        for (id, _) in
 603            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 604        {
 605            configured_servers
 606                .entry(id)
 607                .or_insert(ContextServerSettings::default_extension());
 608        }
 609
 610        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 611            configured_servers
 612                .into_iter()
 613                .partition(|(_, settings)| settings.enabled());
 614
 615        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 616            let id = ContextServerId(id);
 617            ContextServerConfiguration::from_settings(
 618                settings,
 619                id.clone(),
 620                registry.clone(),
 621                worktree_store.clone(),
 622                cx,
 623            )
 624            .map(|config| (id, config))
 625        }))
 626        .await
 627        .into_iter()
 628        .filter_map(|(id, config)| config.map(|config| (id, config)))
 629        .collect::<HashMap<_, _>>();
 630
 631        let mut servers_to_start = Vec::new();
 632        let mut servers_to_remove = HashSet::default();
 633        let mut servers_to_stop = HashSet::default();
 634
 635        this.update(cx, |this, cx| {
 636            for server_id in this.servers.keys() {
 637                // All servers that are not in desired_servers should be removed from the store.
 638                // This can happen if the user removed a server from the context server settings.
 639                if !configured_servers.contains_key(server_id) {
 640                    if disabled_servers.contains_key(&server_id.0) {
 641                        servers_to_stop.insert(server_id.clone());
 642                    } else {
 643                        servers_to_remove.insert(server_id.clone());
 644                    }
 645                }
 646            }
 647
 648            for (id, config) in configured_servers {
 649                let state = this.servers.get(&id);
 650                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 651                let existing_config = state.as_ref().map(|state| state.configuration());
 652                if existing_config.as_deref() != Some(&config) || is_stopped {
 653                    let config = Arc::new(config);
 654                    let server = this.create_context_server(id.clone(), config.clone(), cx)?;
 655                    servers_to_start.push((server, config));
 656                    if this.servers.contains_key(&id) {
 657                        servers_to_stop.insert(id);
 658                    }
 659                }
 660            }
 661
 662            anyhow::Ok(())
 663        })??;
 664
 665        this.update(cx, |this, cx| {
 666            for id in servers_to_stop {
 667                this.stop_server(&id, cx)?;
 668            }
 669            for id in servers_to_remove {
 670                this.remove_server(&id, cx)?;
 671            }
 672            for (server, config) in servers_to_start {
 673                this.run_server(server, config, cx);
 674            }
 675            anyhow::Ok(())
 676        })?
 677    }
 678}
 679
 680#[cfg(test)]
 681mod tests {
 682    use super::*;
 683    use crate::{
 684        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 685        project_settings::ProjectSettings,
 686    };
 687    use context_server::test::create_fake_transport;
 688    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 689    use http_client::{FakeHttpClient, Response};
 690    use serde_json::json;
 691    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 692    use util::path;
 693
 694    #[gpui::test]
 695    async fn test_context_server_status(cx: &mut TestAppContext) {
 696        const SERVER_1_ID: &str = "mcp-1";
 697        const SERVER_2_ID: &str = "mcp-2";
 698
 699        let (_fs, project) = setup_context_server_test(
 700            cx,
 701            json!({"code.rs": ""}),
 702            vec![
 703                (SERVER_1_ID.into(), dummy_server_settings()),
 704                (SERVER_2_ID.into(), dummy_server_settings()),
 705            ],
 706        )
 707        .await;
 708
 709        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 710        let store = cx.new(|cx| {
 711            ContextServerStore::test(
 712                registry.clone(),
 713                project.read(cx).worktree_store(),
 714                project.downgrade(),
 715                cx,
 716            )
 717        });
 718
 719        let server_1_id = ContextServerId(SERVER_1_ID.into());
 720        let server_2_id = ContextServerId(SERVER_2_ID.into());
 721
 722        let server_1 = Arc::new(ContextServer::new(
 723            server_1_id.clone(),
 724            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 725        ));
 726        let server_2 = Arc::new(ContextServer::new(
 727            server_2_id.clone(),
 728            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 729        ));
 730
 731        store.update(cx, |store, cx| store.start_server(server_1, cx));
 732
 733        cx.run_until_parked();
 734
 735        cx.update(|cx| {
 736            assert_eq!(
 737                store.read(cx).status_for_server(&server_1_id),
 738                Some(ContextServerStatus::Running)
 739            );
 740            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 741        });
 742
 743        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 744
 745        cx.run_until_parked();
 746
 747        cx.update(|cx| {
 748            assert_eq!(
 749                store.read(cx).status_for_server(&server_1_id),
 750                Some(ContextServerStatus::Running)
 751            );
 752            assert_eq!(
 753                store.read(cx).status_for_server(&server_2_id),
 754                Some(ContextServerStatus::Running)
 755            );
 756        });
 757
 758        store
 759            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 760            .unwrap();
 761
 762        cx.update(|cx| {
 763            assert_eq!(
 764                store.read(cx).status_for_server(&server_1_id),
 765                Some(ContextServerStatus::Running)
 766            );
 767            assert_eq!(
 768                store.read(cx).status_for_server(&server_2_id),
 769                Some(ContextServerStatus::Stopped)
 770            );
 771        });
 772    }
 773
 774    #[gpui::test]
 775    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 776        const SERVER_1_ID: &str = "mcp-1";
 777        const SERVER_2_ID: &str = "mcp-2";
 778
 779        let (_fs, project) = setup_context_server_test(
 780            cx,
 781            json!({"code.rs": ""}),
 782            vec![
 783                (SERVER_1_ID.into(), dummy_server_settings()),
 784                (SERVER_2_ID.into(), dummy_server_settings()),
 785            ],
 786        )
 787        .await;
 788
 789        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 790        let store = cx.new(|cx| {
 791            ContextServerStore::test(
 792                registry.clone(),
 793                project.read(cx).worktree_store(),
 794                project.downgrade(),
 795                cx,
 796            )
 797        });
 798
 799        let server_1_id = ContextServerId(SERVER_1_ID.into());
 800        let server_2_id = ContextServerId(SERVER_2_ID.into());
 801
 802        let server_1 = Arc::new(ContextServer::new(
 803            server_1_id.clone(),
 804            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 805        ));
 806        let server_2 = Arc::new(ContextServer::new(
 807            server_2_id.clone(),
 808            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 809        ));
 810
 811        let _server_events = assert_server_events(
 812            &store,
 813            vec![
 814                (server_1_id.clone(), ContextServerStatus::Starting),
 815                (server_1_id, ContextServerStatus::Running),
 816                (server_2_id.clone(), ContextServerStatus::Starting),
 817                (server_2_id.clone(), ContextServerStatus::Running),
 818                (server_2_id.clone(), ContextServerStatus::Stopped),
 819            ],
 820            cx,
 821        );
 822
 823        store.update(cx, |store, cx| store.start_server(server_1, cx));
 824
 825        cx.run_until_parked();
 826
 827        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 828
 829        cx.run_until_parked();
 830
 831        store
 832            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 833            .unwrap();
 834    }
 835
 836    #[gpui::test(iterations = 25)]
 837    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 838        const SERVER_1_ID: &str = "mcp-1";
 839
 840        let (_fs, project) = setup_context_server_test(
 841            cx,
 842            json!({"code.rs": ""}),
 843            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 844        )
 845        .await;
 846
 847        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 848        let store = cx.new(|cx| {
 849            ContextServerStore::test(
 850                registry.clone(),
 851                project.read(cx).worktree_store(),
 852                project.downgrade(),
 853                cx,
 854            )
 855        });
 856
 857        let server_id = ContextServerId(SERVER_1_ID.into());
 858
 859        let server_with_same_id_1 = Arc::new(ContextServer::new(
 860            server_id.clone(),
 861            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 862        ));
 863        let server_with_same_id_2 = Arc::new(ContextServer::new(
 864            server_id.clone(),
 865            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 866        ));
 867
 868        // If we start another server with the same id, we should report that we stopped the previous one
 869        let _server_events = assert_server_events(
 870            &store,
 871            vec![
 872                (server_id.clone(), ContextServerStatus::Starting),
 873                (server_id.clone(), ContextServerStatus::Stopped),
 874                (server_id.clone(), ContextServerStatus::Starting),
 875                (server_id.clone(), ContextServerStatus::Running),
 876            ],
 877            cx,
 878        );
 879
 880        store.update(cx, |store, cx| {
 881            store.start_server(server_with_same_id_1.clone(), cx)
 882        });
 883        store.update(cx, |store, cx| {
 884            store.start_server(server_with_same_id_2.clone(), cx)
 885        });
 886
 887        cx.run_until_parked();
 888
 889        cx.update(|cx| {
 890            assert_eq!(
 891                store.read(cx).status_for_server(&server_id),
 892                Some(ContextServerStatus::Running)
 893            );
 894        });
 895    }
 896
 897    #[gpui::test]
 898    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 899        const SERVER_1_ID: &str = "mcp-1";
 900        const SERVER_2_ID: &str = "mcp-2";
 901
 902        let server_1_id = ContextServerId(SERVER_1_ID.into());
 903        let server_2_id = ContextServerId(SERVER_2_ID.into());
 904
 905        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 906
 907        let (_fs, project) = setup_context_server_test(
 908            cx,
 909            json!({"code.rs": ""}),
 910            vec![(
 911                SERVER_1_ID.into(),
 912                ContextServerSettings::Extension {
 913                    enabled: true,
 914                    settings: json!({
 915                        "somevalue": true
 916                    }),
 917                },
 918            )],
 919        )
 920        .await;
 921
 922        let executor = cx.executor();
 923        let registry = cx.new(|cx| {
 924            let mut registry = ContextServerDescriptorRegistry::new();
 925            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 926            registry
 927        });
 928        let store = cx.new(|cx| {
 929            ContextServerStore::test_maintain_server_loop(
 930                Some(Box::new(move |id, _| {
 931                    Arc::new(ContextServer::new(
 932                        id.clone(),
 933                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 934                    ))
 935                })),
 936                registry.clone(),
 937                project.read(cx).worktree_store(),
 938                project.downgrade(),
 939                cx,
 940            )
 941        });
 942
 943        // Ensure that mcp-1 starts up
 944        {
 945            let _server_events = assert_server_events(
 946                &store,
 947                vec![
 948                    (server_1_id.clone(), ContextServerStatus::Starting),
 949                    (server_1_id.clone(), ContextServerStatus::Running),
 950                ],
 951                cx,
 952            );
 953            cx.run_until_parked();
 954        }
 955
 956        // Ensure that mcp-1 is restarted when the configuration was changed
 957        {
 958            let _server_events = assert_server_events(
 959                &store,
 960                vec![
 961                    (server_1_id.clone(), ContextServerStatus::Stopped),
 962                    (server_1_id.clone(), ContextServerStatus::Starting),
 963                    (server_1_id.clone(), ContextServerStatus::Running),
 964                ],
 965                cx,
 966            );
 967            set_context_server_configuration(
 968                vec![(
 969                    server_1_id.0.clone(),
 970                    settings::ContextServerSettingsContent::Extension {
 971                        enabled: true,
 972                        settings: json!({
 973                            "somevalue": false
 974                        }),
 975                    },
 976                )],
 977                cx,
 978            );
 979
 980            cx.run_until_parked();
 981        }
 982
 983        // Ensure that mcp-1 is not restarted when the configuration was not changed
 984        {
 985            let _server_events = assert_server_events(&store, vec![], cx);
 986            set_context_server_configuration(
 987                vec![(
 988                    server_1_id.0.clone(),
 989                    settings::ContextServerSettingsContent::Extension {
 990                        enabled: true,
 991                        settings: json!({
 992                            "somevalue": false
 993                        }),
 994                    },
 995                )],
 996                cx,
 997            );
 998
 999            cx.run_until_parked();
1000        }
1001
1002        // Ensure that mcp-2 is started once it is added to the settings
1003        {
1004            let _server_events = assert_server_events(
1005                &store,
1006                vec![
1007                    (server_2_id.clone(), ContextServerStatus::Starting),
1008                    (server_2_id.clone(), ContextServerStatus::Running),
1009                ],
1010                cx,
1011            );
1012            set_context_server_configuration(
1013                vec![
1014                    (
1015                        server_1_id.0.clone(),
1016                        settings::ContextServerSettingsContent::Extension {
1017                            enabled: true,
1018                            settings: json!({
1019                                "somevalue": false
1020                            }),
1021                        },
1022                    ),
1023                    (
1024                        server_2_id.0.clone(),
1025                        settings::ContextServerSettingsContent::Stdio {
1026                            enabled: true,
1027                            command: ContextServerCommand {
1028                                path: "somebinary".into(),
1029                                args: vec!["arg".to_string()],
1030                                env: None,
1031                                timeout: None,
1032                            },
1033                        },
1034                    ),
1035                ],
1036                cx,
1037            );
1038
1039            cx.run_until_parked();
1040        }
1041
1042        // Ensure that mcp-2 is restarted once the args have changed
1043        {
1044            let _server_events = assert_server_events(
1045                &store,
1046                vec![
1047                    (server_2_id.clone(), ContextServerStatus::Stopped),
1048                    (server_2_id.clone(), ContextServerStatus::Starting),
1049                    (server_2_id.clone(), ContextServerStatus::Running),
1050                ],
1051                cx,
1052            );
1053            set_context_server_configuration(
1054                vec![
1055                    (
1056                        server_1_id.0.clone(),
1057                        settings::ContextServerSettingsContent::Extension {
1058                            enabled: true,
1059                            settings: json!({
1060                                "somevalue": false
1061                            }),
1062                        },
1063                    ),
1064                    (
1065                        server_2_id.0.clone(),
1066                        settings::ContextServerSettingsContent::Stdio {
1067                            enabled: true,
1068                            command: ContextServerCommand {
1069                                path: "somebinary".into(),
1070                                args: vec!["anotherArg".to_string()],
1071                                env: None,
1072                                timeout: None,
1073                            },
1074                        },
1075                    ),
1076                ],
1077                cx,
1078            );
1079
1080            cx.run_until_parked();
1081        }
1082
1083        // Ensure that mcp-2 is removed once it is removed from the settings
1084        {
1085            let _server_events = assert_server_events(
1086                &store,
1087                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1088                cx,
1089            );
1090            set_context_server_configuration(
1091                vec![(
1092                    server_1_id.0.clone(),
1093                    settings::ContextServerSettingsContent::Extension {
1094                        enabled: true,
1095                        settings: json!({
1096                            "somevalue": false
1097                        }),
1098                    },
1099                )],
1100                cx,
1101            );
1102
1103            cx.run_until_parked();
1104
1105            cx.update(|cx| {
1106                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1107            });
1108        }
1109
1110        // Ensure that nothing happens if the settings do not change
1111        {
1112            let _server_events = assert_server_events(&store, vec![], cx);
1113            set_context_server_configuration(
1114                vec![(
1115                    server_1_id.0.clone(),
1116                    settings::ContextServerSettingsContent::Extension {
1117                        enabled: true,
1118                        settings: json!({
1119                            "somevalue": false
1120                        }),
1121                    },
1122                )],
1123                cx,
1124            );
1125
1126            cx.run_until_parked();
1127
1128            cx.update(|cx| {
1129                assert_eq!(
1130                    store.read(cx).status_for_server(&server_1_id),
1131                    Some(ContextServerStatus::Running)
1132                );
1133                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1134            });
1135        }
1136    }
1137
1138    #[gpui::test]
1139    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1140        const SERVER_1_ID: &str = "mcp-1";
1141
1142        let server_1_id = ContextServerId(SERVER_1_ID.into());
1143
1144        let (_fs, project) = setup_context_server_test(
1145            cx,
1146            json!({"code.rs": ""}),
1147            vec![(
1148                SERVER_1_ID.into(),
1149                ContextServerSettings::Stdio {
1150                    enabled: true,
1151                    command: ContextServerCommand {
1152                        path: "somebinary".into(),
1153                        args: vec!["arg".to_string()],
1154                        env: None,
1155                        timeout: None,
1156                    },
1157                },
1158            )],
1159        )
1160        .await;
1161
1162        let executor = cx.executor();
1163        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1164        let store = cx.new(|cx| {
1165            ContextServerStore::test_maintain_server_loop(
1166                Some(Box::new(move |id, _| {
1167                    Arc::new(ContextServer::new(
1168                        id.clone(),
1169                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1170                    ))
1171                })),
1172                registry.clone(),
1173                project.read(cx).worktree_store(),
1174                project.downgrade(),
1175                cx,
1176            )
1177        });
1178
1179        // Ensure that mcp-1 starts up
1180        {
1181            let _server_events = assert_server_events(
1182                &store,
1183                vec![
1184                    (server_1_id.clone(), ContextServerStatus::Starting),
1185                    (server_1_id.clone(), ContextServerStatus::Running),
1186                ],
1187                cx,
1188            );
1189            cx.run_until_parked();
1190        }
1191
1192        // Ensure that mcp-1 is stopped once it is disabled.
1193        {
1194            let _server_events = assert_server_events(
1195                &store,
1196                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1197                cx,
1198            );
1199            set_context_server_configuration(
1200                vec![(
1201                    server_1_id.0.clone(),
1202                    settings::ContextServerSettingsContent::Stdio {
1203                        enabled: false,
1204                        command: ContextServerCommand {
1205                            path: "somebinary".into(),
1206                            args: vec!["arg".to_string()],
1207                            env: None,
1208                            timeout: None,
1209                        },
1210                    },
1211                )],
1212                cx,
1213            );
1214
1215            cx.run_until_parked();
1216        }
1217
1218        // Ensure that mcp-1 is started once it is enabled again.
1219        {
1220            let _server_events = assert_server_events(
1221                &store,
1222                vec![
1223                    (server_1_id.clone(), ContextServerStatus::Starting),
1224                    (server_1_id.clone(), ContextServerStatus::Running),
1225                ],
1226                cx,
1227            );
1228            set_context_server_configuration(
1229                vec![(
1230                    server_1_id.0.clone(),
1231                    settings::ContextServerSettingsContent::Stdio {
1232                        enabled: true,
1233                        command: ContextServerCommand {
1234                            path: "somebinary".into(),
1235                            args: vec!["arg".to_string()],
1236                            timeout: None,
1237                            env: None,
1238                        },
1239                    },
1240                )],
1241                cx,
1242            );
1243
1244            cx.run_until_parked();
1245        }
1246    }
1247
1248    fn set_context_server_configuration(
1249        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1250        cx: &mut TestAppContext,
1251    ) {
1252        cx.update(|cx| {
1253            SettingsStore::update_global(cx, |store, cx| {
1254                store.update_user_settings(cx, |content| {
1255                    content.project.context_servers.clear();
1256                    for (id, config) in context_servers {
1257                        content.project.context_servers.insert(id, config);
1258                    }
1259                });
1260            })
1261        });
1262    }
1263
1264    #[gpui::test]
1265    async fn test_remote_context_server(cx: &mut TestAppContext) {
1266        const SERVER_ID: &str = "remote-server";
1267        let server_id = ContextServerId(SERVER_ID.into());
1268        let server_url = "http://example.com/api";
1269
1270        let (_fs, project) = setup_context_server_test(
1271            cx,
1272            json!({ "code.rs": "" }),
1273            vec![(
1274                SERVER_ID.into(),
1275                ContextServerSettings::Http {
1276                    enabled: true,
1277                    url: server_url.to_string(),
1278                    headers: Default::default(),
1279                },
1280            )],
1281        )
1282        .await;
1283
1284        let client = FakeHttpClient::create(|_| async move {
1285            use http_client::AsyncBody;
1286
1287            let response = Response::builder()
1288                .status(200)
1289                .header("Content-Type", "application/json")
1290                .body(AsyncBody::from(
1291                    serde_json::to_string(&json!({
1292                        "jsonrpc": "2.0",
1293                        "id": 0,
1294                        "result": {
1295                            "protocolVersion": "2024-11-05",
1296                            "capabilities": {},
1297                            "serverInfo": {
1298                                "name": "test-server",
1299                                "version": "1.0.0"
1300                            }
1301                        }
1302                    }))
1303                    .unwrap(),
1304                ))
1305                .unwrap();
1306            Ok(response)
1307        });
1308        cx.update(|cx| cx.set_http_client(client));
1309        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1310        let store = cx.new(|cx| {
1311            ContextServerStore::test_maintain_server_loop(
1312                None,
1313                registry.clone(),
1314                project.read(cx).worktree_store(),
1315                project.downgrade(),
1316                cx,
1317            )
1318        });
1319
1320        let _server_events = assert_server_events(
1321            &store,
1322            vec![
1323                (server_id.clone(), ContextServerStatus::Starting),
1324                (server_id.clone(), ContextServerStatus::Running),
1325            ],
1326            cx,
1327        );
1328        cx.run_until_parked();
1329    }
1330
1331    struct ServerEvents {
1332        received_event_count: Rc<RefCell<usize>>,
1333        expected_event_count: usize,
1334        _subscription: Subscription,
1335    }
1336
1337    impl Drop for ServerEvents {
1338        fn drop(&mut self) {
1339            let actual_event_count = *self.received_event_count.borrow();
1340            assert_eq!(
1341                actual_event_count, self.expected_event_count,
1342                "
1343                Expected to receive {} context server store events, but received {} events",
1344                self.expected_event_count, actual_event_count
1345            );
1346        }
1347    }
1348
1349    fn dummy_server_settings() -> ContextServerSettings {
1350        ContextServerSettings::Stdio {
1351            enabled: true,
1352            command: ContextServerCommand {
1353                path: "somebinary".into(),
1354                args: vec!["arg".to_string()],
1355                env: None,
1356                timeout: None,
1357            },
1358        }
1359    }
1360
1361    fn assert_server_events(
1362        store: &Entity<ContextServerStore>,
1363        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1364        cx: &mut TestAppContext,
1365    ) -> ServerEvents {
1366        cx.update(|cx| {
1367            let mut ix = 0;
1368            let received_event_count = Rc::new(RefCell::new(0));
1369            let expected_event_count = expected_events.len();
1370            let subscription = cx.subscribe(store, {
1371                let received_event_count = received_event_count.clone();
1372                move |_, event, _| match event {
1373                    Event::ServerStatusChanged {
1374                        server_id: actual_server_id,
1375                        status: actual_status,
1376                    } => {
1377                        let (expected_server_id, expected_status) = &expected_events[ix];
1378
1379                        assert_eq!(
1380                            actual_server_id, expected_server_id,
1381                            "Expected different server id at index {}",
1382                            ix
1383                        );
1384                        assert_eq!(
1385                            actual_status, expected_status,
1386                            "Expected different status at index {}",
1387                            ix
1388                        );
1389                        ix += 1;
1390                        *received_event_count.borrow_mut() += 1;
1391                    }
1392                }
1393            });
1394            ServerEvents {
1395                expected_event_count,
1396                received_event_count,
1397                _subscription: subscription,
1398            }
1399        })
1400    }
1401
1402    async fn setup_context_server_test(
1403        cx: &mut TestAppContext,
1404        files: serde_json::Value,
1405        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1406    ) -> (Arc<FakeFs>, Entity<Project>) {
1407        cx.update(|cx| {
1408            let settings_store = SettingsStore::test(cx);
1409            cx.set_global(settings_store);
1410            let mut settings = ProjectSettings::get_global(cx).clone();
1411            for (id, config) in context_server_configurations {
1412                settings.context_servers.insert(id, config);
1413            }
1414            ProjectSettings::override_global(settings, cx);
1415        });
1416
1417        let fs = FakeFs::new(cx.executor());
1418        fs.insert_tree(path!("/test"), files).await;
1419        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1420
1421        (fs, project)
1422    }
1423
1424    struct FakeContextServerDescriptor {
1425        path: PathBuf,
1426    }
1427
1428    impl FakeContextServerDescriptor {
1429        fn new(path: impl Into<PathBuf>) -> Self {
1430            Self { path: path.into() }
1431        }
1432    }
1433
1434    impl ContextServerDescriptor for FakeContextServerDescriptor {
1435        fn command(
1436            &self,
1437            _worktree_store: Entity<WorktreeStore>,
1438            _cx: &AsyncApp,
1439        ) -> Task<Result<ContextServerCommand>> {
1440            Task::ready(Ok(ContextServerCommand {
1441                path: self.path.clone(),
1442                args: vec!["arg1".to_string(), "arg2".to_string()],
1443                env: None,
1444                timeout: None,
1445            }))
1446        }
1447
1448        fn configuration(
1449            &self,
1450            _worktree_store: Entity<WorktreeStore>,
1451            _cx: &AsyncApp,
1452        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1453            Task::ready(Ok(None))
1454        }
1455    }
1456}