context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::sync::Arc;
   5use std::time::Duration;
   6
   7use anyhow::{Context as _, Result};
   8use collections::{HashMap, HashSet};
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use futures::{FutureExt as _, future::join_all};
  11use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  12use registry::ContextServerDescriptorRegistry;
  13use settings::{Settings as _, SettingsStore};
  14use util::{ResultExt as _, rel_path::RelPath};
  15
  16use crate::{
  17    Project,
  18    project_settings::{ContextServerSettings, ProjectSettings},
  19    worktree_store::WorktreeStore,
  20};
  21
  22pub fn init(cx: &mut App) {
  23    extension::init(cx);
  24}
  25
// Declares the `Restart` action in the `context_server` namespace. Its handler is not
// defined in this part of the file.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  33
/// Externally visible lifecycle status of a context server, as returned by
/// `ContextServerStore::status_for_server` and carried by `Event::ServerStatusChanged`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server's start routine is in progress.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// Starting the server failed; holds the stringified error.
    Error(Arc<str>),
}
  41
  42impl ContextServerStatus {
  43    fn from_state(state: &ContextServerState) -> Self {
  44        match state {
  45            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  46            ContextServerState::Running { .. } => ContextServerStatus::Running,
  47            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  48            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  49        }
  50    }
  51}
  52
/// Internal lifecycle state of a context server tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was created from, so the
/// store can report status, compare configurations on settings changes, and restart servers.
enum ContextServerState {
    /// `ContextServer::start` is in progress.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Handle to the background start task. It is never read; it is kept here so the
        // task lives exactly as long as this state does.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped explicitly or because it was disabled in settings.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Starting the server failed.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // The start error rendered with `to_string()`.
        error: Arc<str>,
    },
}
  73
  74impl ContextServerState {
  75    pub fn server(&self) -> Arc<ContextServer> {
  76        match self {
  77            ContextServerState::Starting { server, .. } => server.clone(),
  78            ContextServerState::Running { server, .. } => server.clone(),
  79            ContextServerState::Stopped { server, .. } => server.clone(),
  80            ContextServerState::Error { server, .. } => server.clone(),
  81        }
  82    }
  83
  84    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  85        match self {
  86            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  87            ContextServerState::Running { configuration, .. } => configuration.clone(),
  88            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  89            ContextServerState::Error { configuration, .. } => configuration.clone(),
  90        }
  91    }
  92}
  93
/// Resolved configuration from which a `ContextServer` is built.
///
/// Equality is used to decide whether an already-known server must be restarted after
/// the settings change.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A user-configured server that is launched as a stdio process from `command`.
    Custom {
        command: ContextServerCommand,
    },
    /// A server contributed by an extension. `command` comes from the extension's descriptor
    /// and `settings` are the user's settings entry for that server.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
    /// A server reached over HTTP at `url`.
    Http {
        url: url::Url,
        /// Header map handed to the HTTP transport.
        headers: HashMap<String, String>,
        /// Per-server timeout in milliseconds. `None` falls back to the global
        /// `context_server_timeout` setting.
        timeout: Option<u64>,
    },
}
 109
 110impl ContextServerConfiguration {
 111    pub fn command(&self) -> Option<&ContextServerCommand> {
 112        match self {
 113            ContextServerConfiguration::Custom { command } => Some(command),
 114            ContextServerConfiguration::Extension { command, .. } => Some(command),
 115            ContextServerConfiguration::Http { .. } => None,
 116        }
 117    }
 118
 119    pub async fn from_settings(
 120        settings: ContextServerSettings,
 121        id: ContextServerId,
 122        registry: Entity<ContextServerDescriptorRegistry>,
 123        worktree_store: Entity<WorktreeStore>,
 124        cx: &AsyncApp,
 125    ) -> Option<Self> {
 126        match settings {
 127            ContextServerSettings::Stdio {
 128                enabled: _,
 129                command,
 130            } => Some(ContextServerConfiguration::Custom { command }),
 131            ContextServerSettings::Extension {
 132                enabled: _,
 133                settings,
 134            } => {
 135                let descriptor = cx
 136                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 137                    .ok()
 138                    .flatten()?;
 139
 140                match descriptor.command(worktree_store, cx).await {
 141                    Ok(command) => {
 142                        Some(ContextServerConfiguration::Extension { command, settings })
 143                    }
 144                    Err(e) => {
 145                        log::error!(
 146                            "Failed to create context server configuration from settings: {e:#}"
 147                        );
 148                        None
 149                    }
 150                }
 151            }
 152            ContextServerSettings::Http {
 153                enabled: _,
 154                url,
 155                headers: auth,
 156                timeout,
 157            } => {
 158                let url = url::Url::parse(&url).log_err()?;
 159                Some(ContextServerConfiguration::Http {
 160                    url,
 161                    headers: auth,
 162                    timeout,
 163                })
 164            }
 165        }
 166    }
 167}
 168
/// Hook that builds a `ContextServer` from its id and configuration. When one is installed,
/// `ContextServerStore::create_context_server` uses it instead of the default construction.
/// In this file it is only supplied through the test-only `test_maintain_server_loop`.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 171
/// Owns the context servers of a project and keeps them in sync with settings and with the
/// extension-provided descriptor registry. It emits `Event`s when server status changes.
pub struct ContextServerStore {
    // Context server settings resolved for the first visible worktree, keyed by server name.
    // Refreshed when the global `SettingsStore` changes (only in maintaining mode).
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Lifecycle state of every server the store currently knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to locate settings and as a fallback source of the stdio servers' root directory.
    worktree_store: Entity<WorktreeStore>,
    // Used to find the active project directory for stdio servers.
    project: WeakEntity<Project>,
    // Descriptors for servers contributed by extensions.
    registry: Entity<ContextServerDescriptorRegistry>,
    // In-flight reconciliation task started by `available_context_servers_changed`, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for how servers are constructed.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a reconciliation is requested while one is already running, so it is re-run.
    needs_server_update: bool,
    // Registry and settings observers, kept alive for the lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
 183
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// A server's state changed. Also emitted with `ContextServerStatus::Stopped` when a
    /// server is removed from the store entirely.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
 192
 193impl ContextServerStore {
 194    pub fn new(
 195        worktree_store: Entity<WorktreeStore>,
 196        weak_project: WeakEntity<Project>,
 197        cx: &mut Context<Self>,
 198    ) -> Self {
 199        Self::new_internal(
 200            true,
 201            None,
 202            ContextServerDescriptorRegistry::default_global(cx),
 203            worktree_store,
 204            weak_project,
 205            cx,
 206        )
 207    }
 208
 209    /// Returns all configured context server ids, excluding the ones that are disabled
 210    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 211        self.context_server_settings
 212            .iter()
 213            .filter(|(_, settings)| settings.enabled())
 214            .map(|(id, _)| ContextServerId(id.clone()))
 215            .collect()
 216    }
 217
 218    #[cfg(any(test, feature = "test-support"))]
 219    pub fn test(
 220        registry: Entity<ContextServerDescriptorRegistry>,
 221        worktree_store: Entity<WorktreeStore>,
 222        weak_project: WeakEntity<Project>,
 223        cx: &mut Context<Self>,
 224    ) -> Self {
 225        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
 226    }
 227
 228    #[cfg(any(test, feature = "test-support"))]
 229    pub fn test_maintain_server_loop(
 230        context_server_factory: Option<ContextServerFactory>,
 231        registry: Entity<ContextServerDescriptorRegistry>,
 232        worktree_store: Entity<WorktreeStore>,
 233        weak_project: WeakEntity<Project>,
 234        cx: &mut Context<Self>,
 235    ) -> Self {
 236        Self::new_internal(
 237            true,
 238            context_server_factory,
 239            registry,
 240            worktree_store,
 241            weak_project,
 242            cx,
 243        )
 244    }
 245
 246    fn new_internal(
 247        maintain_server_loop: bool,
 248        context_server_factory: Option<ContextServerFactory>,
 249        registry: Entity<ContextServerDescriptorRegistry>,
 250        worktree_store: Entity<WorktreeStore>,
 251        weak_project: WeakEntity<Project>,
 252        cx: &mut Context<Self>,
 253    ) -> Self {
 254        let subscriptions = if maintain_server_loop {
 255            vec![
 256                cx.observe(&registry, |this, _registry, cx| {
 257                    this.available_context_servers_changed(cx);
 258                }),
 259                cx.observe_global::<SettingsStore>(|this, cx| {
 260                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 261                    if &this.context_server_settings == settings {
 262                        return;
 263                    }
 264                    this.context_server_settings = settings.clone();
 265                    this.available_context_servers_changed(cx);
 266                }),
 267            ]
 268        } else {
 269            Vec::new()
 270        };
 271
 272        let mut this = Self {
 273            _subscriptions: subscriptions,
 274            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 275                .clone(),
 276            worktree_store,
 277            project: weak_project,
 278            registry,
 279            needs_server_update: false,
 280            servers: HashMap::default(),
 281            update_servers_task: None,
 282            context_server_factory,
 283        };
 284        if maintain_server_loop {
 285            this.available_context_servers_changed(cx);
 286        }
 287        this
 288    }
 289
 290    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 291        self.servers.get(id).map(|state| state.server())
 292    }
 293
 294    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 295        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 296            Some(server.clone())
 297        } else {
 298            None
 299        }
 300    }
 301
 302    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 303        self.servers.get(id).map(ContextServerStatus::from_state)
 304    }
 305
 306    pub fn configuration_for_server(
 307        &self,
 308        id: &ContextServerId,
 309    ) -> Option<Arc<ContextServerConfiguration>> {
 310        self.servers.get(id).map(|state| state.configuration())
 311    }
 312
 313    pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
 314        self.servers
 315            .keys()
 316            .cloned()
 317            .chain(
 318                self.registry
 319                    .read(cx)
 320                    .context_server_descriptors()
 321                    .into_iter()
 322                    .map(|(id, _)| ContextServerId(id)),
 323            )
 324            .collect()
 325    }
 326
 327    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 328        self.servers
 329            .values()
 330            .filter_map(|state| {
 331                if let ContextServerState::Running { server, .. } = state {
 332                    Some(server.clone())
 333                } else {
 334                    None
 335                }
 336            })
 337            .collect()
 338    }
 339
 340    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 341        cx.spawn(async move |this, cx| {
 342            let this = this.upgrade().context("Context server store dropped")?;
 343            let settings = this
 344                .update(cx, |this, _| {
 345                    this.context_server_settings.get(&server.id().0).cloned()
 346                })
 347                .ok()
 348                .flatten()
 349                .context("Failed to get context server settings")?;
 350
 351            if !settings.enabled() {
 352                return Ok(());
 353            }
 354
 355            let (registry, worktree_store) = this.update(cx, |this, _| {
 356                (this.registry.clone(), this.worktree_store.clone())
 357            })?;
 358            let configuration = ContextServerConfiguration::from_settings(
 359                settings,
 360                server.id(),
 361                registry,
 362                worktree_store,
 363                cx,
 364            )
 365            .await
 366            .context("Failed to create context server configuration")?;
 367
 368            this.update(cx, |this, cx| {
 369                this.run_server(server, Arc::new(configuration), cx)
 370            })
 371        })
 372        .detach_and_log_err(cx);
 373    }
 374
 375    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 376        if matches!(
 377            self.servers.get(id),
 378            Some(ContextServerState::Stopped { .. })
 379        ) {
 380            return Ok(());
 381        }
 382
 383        let state = self
 384            .servers
 385            .remove(id)
 386            .context("Context server not found")?;
 387
 388        let server = state.server();
 389        let configuration = state.configuration();
 390        let mut result = Ok(());
 391        if let ContextServerState::Running { server, .. } = &state {
 392            result = server.stop();
 393        }
 394        drop(state);
 395
 396        self.update_server_state(
 397            id.clone(),
 398            ContextServerState::Stopped {
 399                configuration,
 400                server,
 401            },
 402            cx,
 403        );
 404
 405        result
 406    }
 407
    /// Starts `server` with `configuration`, replacing any instance with the same id.
    ///
    /// If a server with this id is currently `Starting` or `Running`, it is stopped first,
    /// which emits a `Stopped` status. The store then records a `Starting` state that owns
    /// the spawned start task. When `ContextServer::start` completes, the state becomes
    /// `Running` on success or `Error` (with the error text) on failure. Every transition
    /// goes through `update_server_state`, which emits `Event::ServerStatusChanged`.
    fn run_server(
        &mut self,
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        cx: &mut Context<Self>,
    ) {
        let id = server.id();
        // Tear down an in-flight or live instance so at most one server per id is active.
        if matches!(
            self.servers.get(&id),
            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
        ) {
            self.stop_server(&id, cx).log_err();
        }
        // The handle to this task is stored in the `Starting` state below, so replacing or
        // removing that state drops the handle.
        // NOTE(review): relies on gpui cancelling tasks on drop — confirm.
        let task = cx.spawn({
            let id = server.id();
            let server = server.clone();
            let configuration = configuration.clone();

            async move |this, cx| {
                match server.clone().start(cx).await {
                    Ok(_) => {
                        debug_assert!(server.client().is_some());

                        // A failed update (store already released) is only logged.
                        this.update(cx, |this, cx| {
                            this.update_server_state(
                                id.clone(),
                                ContextServerState::Running {
                                    server,
                                    configuration,
                                },
                                cx,
                            )
                        })
                        .log_err()
                    }
                    Err(err) => {
                        log::error!("{} context server failed to start: {}", id, err);
                        this.update(cx, |this, cx| {
                            this.update_server_state(
                                id.clone(),
                                ContextServerState::Error {
                                    configuration,
                                    server,
                                    error: err.to_string().into(),
                                },
                                cx,
                            )
                        })
                        .log_err()
                    }
                };
            }
        });

        self.update_server_state(
            id.clone(),
            ContextServerState::Starting {
                configuration,
                _task: task,
                server,
            },
            cx,
        );
    }
 472
 473    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 474        let state = self
 475            .servers
 476            .remove(id)
 477            .context("Context server not found")?;
 478        drop(state);
 479        cx.emit(Event::ServerStatusChanged {
 480            server_id: id.clone(),
 481            status: ContextServerStatus::Stopped,
 482        });
 483        Ok(())
 484    }
 485
 486    fn create_context_server(
 487        &self,
 488        id: ContextServerId,
 489        configuration: Arc<ContextServerConfiguration>,
 490        cx: &mut Context<Self>,
 491    ) -> Result<Arc<ContextServer>> {
 492        // Get global timeout from settings
 493        let global_timeout = ProjectSettings::get_global(cx).context_server_timeout;
 494
 495        if let Some(factory) = self.context_server_factory.as_ref() {
 496            return Ok(factory(id, configuration));
 497        }
 498
 499        match configuration.as_ref() {
 500            ContextServerConfiguration::Http {
 501                url,
 502                headers,
 503                timeout,
 504            } => {
 505                // Apply timeout precedence for HTTP servers: per-server > global
 506                let resolved_timeout = timeout.unwrap_or(global_timeout);
 507
 508                Ok(Arc::new(ContextServer::http(
 509                    id,
 510                    url,
 511                    headers.clone(),
 512                    cx.http_client(),
 513                    cx.background_executor().clone(),
 514                    Some(Duration::from_millis(resolved_timeout)),
 515                )?))
 516            }
 517            _ => {
 518                let root_path = self
 519                    .project
 520                    .read_with(cx, |project, cx| project.active_project_directory(cx))
 521                    .ok()
 522                    .flatten()
 523                    .or_else(|| {
 524                        self.worktree_store.read_with(cx, |store, cx| {
 525                            store.visible_worktrees(cx).fold(None, |acc, item| {
 526                                if acc.is_none() {
 527                                    item.read(cx).root_dir()
 528                                } else {
 529                                    acc
 530                                }
 531                            })
 532                        })
 533                    });
 534
 535                // Apply timeout precedence for stdio servers: per-server > global
 536                let mut command_with_timeout = configuration.command().unwrap().clone();
 537                if command_with_timeout.timeout.is_none() {
 538                    command_with_timeout.timeout = Some(global_timeout);
 539                }
 540
 541                Ok(Arc::new(ContextServer::stdio(
 542                    id,
 543                    command_with_timeout,
 544                    root_path,
 545                )))
 546            }
 547        }
 548    }
 549
 550    fn resolve_context_server_settings<'a>(
 551        worktree_store: &'a Entity<WorktreeStore>,
 552        cx: &'a App,
 553    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 554        let location = worktree_store
 555            .read(cx)
 556            .visible_worktrees(cx)
 557            .next()
 558            .map(|worktree| settings::SettingsLocation {
 559                worktree_id: worktree.read(cx).id(),
 560                path: RelPath::empty(),
 561            });
 562        &ProjectSettings::get(location, cx).context_servers
 563    }
 564
 565    fn update_server_state(
 566        &mut self,
 567        id: ContextServerId,
 568        state: ContextServerState,
 569        cx: &mut Context<Self>,
 570    ) {
 571        let status = ContextServerStatus::from_state(&state);
 572        self.servers.insert(id.clone(), state);
 573        cx.emit(Event::ServerStatusChanged {
 574            server_id: id,
 575            status,
 576        });
 577    }
 578
 579    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 580        if self.update_servers_task.is_some() {
 581            self.needs_server_update = true;
 582        } else {
 583            self.needs_server_update = false;
 584            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 585                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 586                    log::error!("Error maintaining context servers: {}", err);
 587                }
 588
 589                this.update(cx, |this, cx| {
 590                    this.update_servers_task.take();
 591                    if this.needs_server_update {
 592                        this.available_context_servers_changed(cx);
 593                    }
 594                })?;
 595
 596                Ok(())
 597            }));
 598        }
 599    }
 600
 601    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 602        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 603            (
 604                this.context_server_settings.clone(),
 605                this.registry.clone(),
 606                this.worktree_store.clone(),
 607            )
 608        })?;
 609
 610        for (id, _) in
 611            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 612        {
 613            configured_servers
 614                .entry(id)
 615                .or_insert(ContextServerSettings::default_extension());
 616        }
 617
 618        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 619            configured_servers
 620                .into_iter()
 621                .partition(|(_, settings)| settings.enabled());
 622
 623        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 624            let id = ContextServerId(id);
 625            ContextServerConfiguration::from_settings(
 626                settings,
 627                id.clone(),
 628                registry.clone(),
 629                worktree_store.clone(),
 630                cx,
 631            )
 632            .map(|config| (id, config))
 633        }))
 634        .await
 635        .into_iter()
 636        .filter_map(|(id, config)| config.map(|config| (id, config)))
 637        .collect::<HashMap<_, _>>();
 638
 639        let mut servers_to_start = Vec::new();
 640        let mut servers_to_remove = HashSet::default();
 641        let mut servers_to_stop = HashSet::default();
 642
 643        this.update(cx, |this, cx| {
 644            for server_id in this.servers.keys() {
 645                // All servers that are not in desired_servers should be removed from the store.
 646                // This can happen if the user removed a server from the context server settings.
 647                if !configured_servers.contains_key(server_id) {
 648                    if disabled_servers.contains_key(&server_id.0) {
 649                        servers_to_stop.insert(server_id.clone());
 650                    } else {
 651                        servers_to_remove.insert(server_id.clone());
 652                    }
 653                }
 654            }
 655
 656            for (id, config) in configured_servers {
 657                let state = this.servers.get(&id);
 658                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 659                let existing_config = state.as_ref().map(|state| state.configuration());
 660                if existing_config.as_deref() != Some(&config) || is_stopped {
 661                    let config = Arc::new(config);
 662                    let server = this.create_context_server(id.clone(), config.clone(), cx)?;
 663                    servers_to_start.push((server, config));
 664                    if this.servers.contains_key(&id) {
 665                        servers_to_stop.insert(id);
 666                    }
 667                }
 668            }
 669
 670            anyhow::Ok(())
 671        })??;
 672
 673        this.update(cx, |this, cx| {
 674            for id in servers_to_stop {
 675                this.stop_server(&id, cx)?;
 676            }
 677            for id in servers_to_remove {
 678                this.remove_server(&id, cx)?;
 679            }
 680            for (server, config) in servers_to_start {
 681                this.run_server(server, config, cx);
 682            }
 683            anyhow::Ok(())
 684        })?
 685    }
 686}
 687
 688#[cfg(test)]
 689mod tests {
 690    use super::*;
 691    use crate::{
 692        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 693        project_settings::ProjectSettings,
 694    };
 695    use context_server::test::create_fake_transport;
 696    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 697    use http_client::{FakeHttpClient, Response};
 698    use serde_json::json;
 699    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 700    use util::path;
 701
    /// Checks that `status_for_server` tracks explicit start/stop calls for two servers
    /// independently. The store is built with `ContextServerStore::test`, so no settings- or
    /// registry-driven reconciliation runs in the background.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        // NOTE(review): presumably registers settings entries for both ids; `start_server`
        // bails out when it finds no settings entry for a server.
        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        // Both servers talk over fake transports from `create_fake_transport`.
        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // Starting server 1 must not create an entry for server 2.
        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        // Starting server 2 leaves server 1 running.
        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        // Stopping takes effect synchronously: no executor run is needed before asserting.
        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }
 781
    /// Checks the exact sequence of `ServerStatusChanged` events produced by starting two
    /// servers one after the other and then stopping the second.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";
        const SERVER_2_ID: &str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // Expected events, in order. The returned guard is kept alive for the whole test;
        // NOTE(review): presumably it performs the check when dropped — confirm in
        // `assert_server_events`.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id, ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }
 843
    /// Starting a second server with the same id while the first is still starting must
    /// stop the first (emitting `Stopped`) and leave only the second running. The test is
    /// repeated 25 times, presumably to exercise different task interleavings.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(
                registry.clone(),
                project.read(cx).worktree_store(),
                project.downgrade(),
                cx,
            )
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        // Two distinct server instances that share one id.
        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        // Both starts are queued before the executor runs, so the second `run_server` finds
        // the first still in `Starting` and stops it.
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }
 904
 905    #[gpui::test]
 906    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 907        const SERVER_1_ID: &str = "mcp-1";
 908        const SERVER_2_ID: &str = "mcp-2";
 909
 910        let server_1_id = ContextServerId(SERVER_1_ID.into());
 911        let server_2_id = ContextServerId(SERVER_2_ID.into());
 912
 913        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 914
 915        let (_fs, project) = setup_context_server_test(
 916            cx,
 917            json!({"code.rs": ""}),
 918            vec![(
 919                SERVER_1_ID.into(),
 920                ContextServerSettings::Extension {
 921                    enabled: true,
 922                    settings: json!({
 923                        "somevalue": true
 924                    }),
 925                },
 926            )],
 927        )
 928        .await;
 929
 930        let executor = cx.executor();
 931        let registry = cx.new(|cx| {
 932            let mut registry = ContextServerDescriptorRegistry::new();
 933            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 934            registry
 935        });
 936        let store = cx.new(|cx| {
 937            ContextServerStore::test_maintain_server_loop(
 938                Some(Box::new(move |id, _| {
 939                    Arc::new(ContextServer::new(
 940                        id.clone(),
 941                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 942                    ))
 943                })),
 944                registry.clone(),
 945                project.read(cx).worktree_store(),
 946                project.downgrade(),
 947                cx,
 948            )
 949        });
 950
 951        // Ensure that mcp-1 starts up
 952        {
 953            let _server_events = assert_server_events(
 954                &store,
 955                vec![
 956                    (server_1_id.clone(), ContextServerStatus::Starting),
 957                    (server_1_id.clone(), ContextServerStatus::Running),
 958                ],
 959                cx,
 960            );
 961            cx.run_until_parked();
 962        }
 963
 964        // Ensure that mcp-1 is restarted when the configuration was changed
 965        {
 966            let _server_events = assert_server_events(
 967                &store,
 968                vec![
 969                    (server_1_id.clone(), ContextServerStatus::Stopped),
 970                    (server_1_id.clone(), ContextServerStatus::Starting),
 971                    (server_1_id.clone(), ContextServerStatus::Running),
 972                ],
 973                cx,
 974            );
 975            set_context_server_configuration(
 976                vec![(
 977                    server_1_id.0.clone(),
 978                    settings::ContextServerSettingsContent::Extension {
 979                        enabled: true,
 980                        settings: json!({
 981                            "somevalue": false
 982                        }),
 983                    },
 984                )],
 985                cx,
 986            );
 987
 988            cx.run_until_parked();
 989        }
 990
 991        // Ensure that mcp-1 is not restarted when the configuration was not changed
 992        {
 993            let _server_events = assert_server_events(&store, vec![], cx);
 994            set_context_server_configuration(
 995                vec![(
 996                    server_1_id.0.clone(),
 997                    settings::ContextServerSettingsContent::Extension {
 998                        enabled: true,
 999                        settings: json!({
1000                            "somevalue": false
1001                        }),
1002                    },
1003                )],
1004                cx,
1005            );
1006
1007            cx.run_until_parked();
1008        }
1009
1010        // Ensure that mcp-2 is started once it is added to the settings
1011        {
1012            let _server_events = assert_server_events(
1013                &store,
1014                vec![
1015                    (server_2_id.clone(), ContextServerStatus::Starting),
1016                    (server_2_id.clone(), ContextServerStatus::Running),
1017                ],
1018                cx,
1019            );
1020            set_context_server_configuration(
1021                vec![
1022                    (
1023                        server_1_id.0.clone(),
1024                        settings::ContextServerSettingsContent::Extension {
1025                            enabled: true,
1026                            settings: json!({
1027                                "somevalue": false
1028                            }),
1029                        },
1030                    ),
1031                    (
1032                        server_2_id.0.clone(),
1033                        settings::ContextServerSettingsContent::Stdio {
1034                            enabled: true,
1035                            command: ContextServerCommand {
1036                                path: "somebinary".into(),
1037                                args: vec!["arg".to_string()],
1038                                env: None,
1039                                timeout: None,
1040                            },
1041                        },
1042                    ),
1043                ],
1044                cx,
1045            );
1046
1047            cx.run_until_parked();
1048        }
1049
1050        // Ensure that mcp-2 is restarted once the args have changed
1051        {
1052            let _server_events = assert_server_events(
1053                &store,
1054                vec![
1055                    (server_2_id.clone(), ContextServerStatus::Stopped),
1056                    (server_2_id.clone(), ContextServerStatus::Starting),
1057                    (server_2_id.clone(), ContextServerStatus::Running),
1058                ],
1059                cx,
1060            );
1061            set_context_server_configuration(
1062                vec![
1063                    (
1064                        server_1_id.0.clone(),
1065                        settings::ContextServerSettingsContent::Extension {
1066                            enabled: true,
1067                            settings: json!({
1068                                "somevalue": false
1069                            }),
1070                        },
1071                    ),
1072                    (
1073                        server_2_id.0.clone(),
1074                        settings::ContextServerSettingsContent::Stdio {
1075                            enabled: true,
1076                            command: ContextServerCommand {
1077                                path: "somebinary".into(),
1078                                args: vec!["anotherArg".to_string()],
1079                                env: None,
1080                                timeout: None,
1081                            },
1082                        },
1083                    ),
1084                ],
1085                cx,
1086            );
1087
1088            cx.run_until_parked();
1089        }
1090
1091        // Ensure that mcp-2 is removed once it is removed from the settings
1092        {
1093            let _server_events = assert_server_events(
1094                &store,
1095                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1096                cx,
1097            );
1098            set_context_server_configuration(
1099                vec![(
1100                    server_1_id.0.clone(),
1101                    settings::ContextServerSettingsContent::Extension {
1102                        enabled: true,
1103                        settings: json!({
1104                            "somevalue": false
1105                        }),
1106                    },
1107                )],
1108                cx,
1109            );
1110
1111            cx.run_until_parked();
1112
1113            cx.update(|cx| {
1114                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1115            });
1116        }
1117
1118        // Ensure that nothing happens if the settings do not change
1119        {
1120            let _server_events = assert_server_events(&store, vec![], cx);
1121            set_context_server_configuration(
1122                vec![(
1123                    server_1_id.0.clone(),
1124                    settings::ContextServerSettingsContent::Extension {
1125                        enabled: true,
1126                        settings: json!({
1127                            "somevalue": false
1128                        }),
1129                    },
1130                )],
1131                cx,
1132            );
1133
1134            cx.run_until_parked();
1135
1136            cx.update(|cx| {
1137                assert_eq!(
1138                    store.read(cx).status_for_server(&server_1_id),
1139                    Some(ContextServerStatus::Running)
1140                );
1141                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1142            });
1143        }
1144    }
1145
1146    #[gpui::test]
1147    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1148        const SERVER_1_ID: &str = "mcp-1";
1149
1150        let server_1_id = ContextServerId(SERVER_1_ID.into());
1151
1152        let (_fs, project) = setup_context_server_test(
1153            cx,
1154            json!({"code.rs": ""}),
1155            vec![(
1156                SERVER_1_ID.into(),
1157                ContextServerSettings::Stdio {
1158                    enabled: true,
1159                    command: ContextServerCommand {
1160                        path: "somebinary".into(),
1161                        args: vec!["arg".to_string()],
1162                        env: None,
1163                        timeout: None,
1164                    },
1165                },
1166            )],
1167        )
1168        .await;
1169
1170        let executor = cx.executor();
1171        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1172        let store = cx.new(|cx| {
1173            ContextServerStore::test_maintain_server_loop(
1174                Some(Box::new(move |id, _| {
1175                    Arc::new(ContextServer::new(
1176                        id.clone(),
1177                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1178                    ))
1179                })),
1180                registry.clone(),
1181                project.read(cx).worktree_store(),
1182                project.downgrade(),
1183                cx,
1184            )
1185        });
1186
1187        // Ensure that mcp-1 starts up
1188        {
1189            let _server_events = assert_server_events(
1190                &store,
1191                vec![
1192                    (server_1_id.clone(), ContextServerStatus::Starting),
1193                    (server_1_id.clone(), ContextServerStatus::Running),
1194                ],
1195                cx,
1196            );
1197            cx.run_until_parked();
1198        }
1199
1200        // Ensure that mcp-1 is stopped once it is disabled.
1201        {
1202            let _server_events = assert_server_events(
1203                &store,
1204                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1205                cx,
1206            );
1207            set_context_server_configuration(
1208                vec![(
1209                    server_1_id.0.clone(),
1210                    settings::ContextServerSettingsContent::Stdio {
1211                        enabled: false,
1212                        command: ContextServerCommand {
1213                            path: "somebinary".into(),
1214                            args: vec!["arg".to_string()],
1215                            env: None,
1216                            timeout: None,
1217                        },
1218                    },
1219                )],
1220                cx,
1221            );
1222
1223            cx.run_until_parked();
1224        }
1225
1226        // Ensure that mcp-1 is started once it is enabled again.
1227        {
1228            let _server_events = assert_server_events(
1229                &store,
1230                vec![
1231                    (server_1_id.clone(), ContextServerStatus::Starting),
1232                    (server_1_id.clone(), ContextServerStatus::Running),
1233                ],
1234                cx,
1235            );
1236            set_context_server_configuration(
1237                vec![(
1238                    server_1_id.0.clone(),
1239                    settings::ContextServerSettingsContent::Stdio {
1240                        enabled: true,
1241                        command: ContextServerCommand {
1242                            path: "somebinary".into(),
1243                            args: vec!["arg".to_string()],
1244                            timeout: None,
1245                            env: None,
1246                        },
1247                    },
1248                )],
1249                cx,
1250            );
1251
1252            cx.run_until_parked();
1253        }
1254    }
1255
1256    fn set_context_server_configuration(
1257        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1258        cx: &mut TestAppContext,
1259    ) {
1260        cx.update(|cx| {
1261            SettingsStore::update_global(cx, |store, cx| {
1262                store.update_user_settings(cx, |content| {
1263                    content.project.context_servers.clear();
1264                    for (id, config) in context_servers {
1265                        content.project.context_servers.insert(id, config);
1266                    }
1267                });
1268            })
1269        });
1270    }
1271
1272    #[gpui::test]
1273    async fn test_remote_context_server(cx: &mut TestAppContext) {
1274        const SERVER_ID: &str = "remote-server";
1275        let server_id = ContextServerId(SERVER_ID.into());
1276        let server_url = "http://example.com/api";
1277
1278        let (_fs, project) = setup_context_server_test(
1279            cx,
1280            json!({ "code.rs": "" }),
1281            vec![(
1282                SERVER_ID.into(),
1283                ContextServerSettings::Http {
1284                    enabled: true,
1285                    url: server_url.to_string(),
1286                    headers: Default::default(),
1287                    timeout: None,
1288                },
1289            )],
1290        )
1291        .await;
1292
1293        let client = FakeHttpClient::create(|_| async move {
1294            use http_client::AsyncBody;
1295
1296            let response = Response::builder()
1297                .status(200)
1298                .header("Content-Type", "application/json")
1299                .body(AsyncBody::from(
1300                    serde_json::to_string(&json!({
1301                        "jsonrpc": "2.0",
1302                        "id": 0,
1303                        "result": {
1304                            "protocolVersion": "2024-11-05",
1305                            "capabilities": {},
1306                            "serverInfo": {
1307                                "name": "test-server",
1308                                "version": "1.0.0"
1309                            }
1310                        }
1311                    }))
1312                    .unwrap(),
1313                ))
1314                .unwrap();
1315            Ok(response)
1316        });
1317        cx.update(|cx| cx.set_http_client(client));
1318        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1319        let store = cx.new(|cx| {
1320            ContextServerStore::test_maintain_server_loop(
1321                None,
1322                registry.clone(),
1323                project.read(cx).worktree_store(),
1324                project.downgrade(),
1325                cx,
1326            )
1327        });
1328
1329        let _server_events = assert_server_events(
1330            &store,
1331            vec![
1332                (server_id.clone(), ContextServerStatus::Starting),
1333                (server_id.clone(), ContextServerStatus::Running),
1334            ],
1335            cx,
1336        );
1337        cx.run_until_parked();
1338    }
1339
    /// Guard returned by `assert_server_events`.
    ///
    /// It holds the store subscription and a count of the events received through it, shared
    /// with the subscription callback. When the guard is dropped, that count is compared against
    /// `expected_event_count` (see the `Drop` impl below).
    struct ServerEvents {
        /// Number of `ServerStatusChanged` events received so far; shared with the callback.
        received_event_count: Rc<RefCell<usize>>,
        /// Number of events the caller expects to receive while the guard is alive.
        expected_event_count: usize,
        /// Held only so the subscription stays active for the guard's lifetime.
        _subscription: Subscription,
    }
1345
1346    impl Drop for ServerEvents {
1347        fn drop(&mut self) {
1348            let actual_event_count = *self.received_event_count.borrow();
1349            assert_eq!(
1350                actual_event_count, self.expected_event_count,
1351                "
1352                Expected to receive {} context server store events, but received {} events",
1353                self.expected_event_count, actual_event_count
1354            );
1355        }
1356    }
1357
1358    fn dummy_server_settings() -> ContextServerSettings {
1359        ContextServerSettings::Stdio {
1360            enabled: true,
1361            command: ContextServerCommand {
1362                path: "somebinary".into(),
1363                args: vec!["arg".to_string()],
1364                env: None,
1365                timeout: None,
1366            },
1367        }
1368    }
1369
1370    fn assert_server_events(
1371        store: &Entity<ContextServerStore>,
1372        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1373        cx: &mut TestAppContext,
1374    ) -> ServerEvents {
1375        cx.update(|cx| {
1376            let mut ix = 0;
1377            let received_event_count = Rc::new(RefCell::new(0));
1378            let expected_event_count = expected_events.len();
1379            let subscription = cx.subscribe(store, {
1380                let received_event_count = received_event_count.clone();
1381                move |_, event, _| match event {
1382                    Event::ServerStatusChanged {
1383                        server_id: actual_server_id,
1384                        status: actual_status,
1385                    } => {
1386                        let (expected_server_id, expected_status) = &expected_events[ix];
1387
1388                        assert_eq!(
1389                            actual_server_id, expected_server_id,
1390                            "Expected different server id at index {}",
1391                            ix
1392                        );
1393                        assert_eq!(
1394                            actual_status, expected_status,
1395                            "Expected different status at index {}",
1396                            ix
1397                        );
1398                        ix += 1;
1399                        *received_event_count.borrow_mut() += 1;
1400                    }
1401                }
1402            });
1403            ServerEvents {
1404                expected_event_count,
1405                received_event_count,
1406                _subscription: subscription,
1407            }
1408        })
1409    }
1410
1411    async fn setup_context_server_test(
1412        cx: &mut TestAppContext,
1413        files: serde_json::Value,
1414        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1415    ) -> (Arc<FakeFs>, Entity<Project>) {
1416        cx.update(|cx| {
1417            let settings_store = SettingsStore::test(cx);
1418            cx.set_global(settings_store);
1419            let mut settings = ProjectSettings::get_global(cx).clone();
1420            for (id, config) in context_server_configurations {
1421                settings.context_servers.insert(id, config);
1422            }
1423            ProjectSettings::override_global(settings, cx);
1424        });
1425
1426        let fs = FakeFs::new(cx.executor());
1427        fs.insert_tree(path!("/test"), files).await;
1428        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1429
1430        (fs, project)
1431    }
1432
1433    struct FakeContextServerDescriptor {
1434        path: PathBuf,
1435    }
1436
1437    impl FakeContextServerDescriptor {
1438        fn new(path: impl Into<PathBuf>) -> Self {
1439            Self { path: path.into() }
1440        }
1441    }
1442
1443    impl ContextServerDescriptor for FakeContextServerDescriptor {
1444        fn command(
1445            &self,
1446            _worktree_store: Entity<WorktreeStore>,
1447            _cx: &AsyncApp,
1448        ) -> Task<Result<ContextServerCommand>> {
1449            Task::ready(Ok(ContextServerCommand {
1450                path: self.path.clone(),
1451                args: vec!["arg1".to_string(), "arg2".to_string()],
1452                env: None,
1453                timeout: None,
1454            }))
1455        }
1456
1457        fn configuration(
1458            &self,
1459            _worktree_store: Entity<WorktreeStore>,
1460            _cx: &AsyncApp,
1461        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1462            Task::ready(Ok(None))
1463        }
1464    }
1465}