context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::{ResultExt as _, rel_path::RelPath};
  14
  15use crate::{
  16    Project,
  17    project_settings::{ContextServerSettings, ProjectSettings},
  18    worktree_store::WorktreeStore,
  19};
  20
  21pub fn init(cx: &mut App) {
  22    extension::init(cx);
  23}
  24
// Declares the actions in the `context_server` namespace. The `///` comment written on each
// action below is passed to the `actions!` macro together with the action name.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  32
  33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  34pub enum ContextServerStatus {
  35    Starting,
  36    Running,
  37    Stopped,
  38    Error(Arc<str>),
  39}
  40
  41impl ContextServerStatus {
  42    fn from_state(state: &ContextServerState) -> Self {
  43        match state {
  44            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  45            ContextServerState::Running { .. } => ContextServerStatus::Running,
  46            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  47            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  48        }
  49    }
  50}
  51
  52enum ContextServerState {
  53    Starting {
  54        server: Arc<ContextServer>,
  55        configuration: Arc<ContextServerConfiguration>,
  56        _task: Task<()>,
  57    },
  58    Running {
  59        server: Arc<ContextServer>,
  60        configuration: Arc<ContextServerConfiguration>,
  61    },
  62    Stopped {
  63        server: Arc<ContextServer>,
  64        configuration: Arc<ContextServerConfiguration>,
  65    },
  66    Error {
  67        server: Arc<ContextServer>,
  68        configuration: Arc<ContextServerConfiguration>,
  69        error: Arc<str>,
  70    },
  71}
  72
  73impl ContextServerState {
  74    pub fn server(&self) -> Arc<ContextServer> {
  75        match self {
  76            ContextServerState::Starting { server, .. } => server.clone(),
  77            ContextServerState::Running { server, .. } => server.clone(),
  78            ContextServerState::Stopped { server, .. } => server.clone(),
  79            ContextServerState::Error { server, .. } => server.clone(),
  80        }
  81    }
  82
  83    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  84        match self {
  85            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  86            ContextServerState::Running { configuration, .. } => configuration.clone(),
  87            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  88            ContextServerState::Error { configuration, .. } => configuration.clone(),
  89        }
  90    }
  91}
  92
  93#[derive(Debug, PartialEq, Eq)]
  94pub enum ContextServerConfiguration {
  95    Custom {
  96        command: ContextServerCommand,
  97    },
  98    Extension {
  99        command: ContextServerCommand,
 100        settings: serde_json::Value,
 101    },
 102    Http {
 103        url: url::Url,
 104        headers: HashMap<String, String>,
 105    },
 106}
 107
 108impl ContextServerConfiguration {
 109    pub fn command(&self) -> Option<&ContextServerCommand> {
 110        match self {
 111            ContextServerConfiguration::Custom { command } => Some(command),
 112            ContextServerConfiguration::Extension { command, .. } => Some(command),
 113            ContextServerConfiguration::Http { .. } => None,
 114        }
 115    }
 116
 117    pub async fn from_settings(
 118        settings: ContextServerSettings,
 119        id: ContextServerId,
 120        registry: Entity<ContextServerDescriptorRegistry>,
 121        worktree_store: Entity<WorktreeStore>,
 122        cx: &AsyncApp,
 123    ) -> Option<Self> {
 124        match settings {
 125            ContextServerSettings::Stdio {
 126                enabled: _,
 127                command,
 128            } => Some(ContextServerConfiguration::Custom { command }),
 129            ContextServerSettings::Extension {
 130                enabled: _,
 131                settings,
 132            } => {
 133                let descriptor = cx
 134                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 135                    .ok()
 136                    .flatten()?;
 137
 138                match descriptor.command(worktree_store, cx).await {
 139                    Ok(command) => {
 140                        Some(ContextServerConfiguration::Extension { command, settings })
 141                    }
 142                    Err(e) => {
 143                        log::error!(
 144                            "Failed to create context server configuration from settings: {e:#}"
 145                        );
 146                        None
 147                    }
 148                }
 149            }
 150            ContextServerSettings::Http {
 151                enabled: _,
 152                url,
 153                headers: auth,
 154            } => {
 155                let url = url::Url::parse(&url).log_err()?;
 156                Some(ContextServerConfiguration::Http { url, headers: auth })
 157            }
 158        }
 159    }
 160}
 161
/// Builds a [`ContextServer`] for a given id and configuration.
///
/// When the store holds one, `create_context_server` calls it instead of constructing an
/// HTTP or stdio server itself. In this file it is only supplied through
/// `ContextServerStore::test_maintain_server_loop`.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;

/// Tracks the context servers configured for a project and their lifecycle state.
///
/// When created with the maintenance loop enabled, it re-reconciles its servers whenever the
/// settings or the descriptor registry change.
pub struct ContextServerStore {
    // Context server settings keyed by server id, resolved for the first visible worktree
    // (see `resolve_context_server_settings`).
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    // Current state of every server the store knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to resolve settings and to find a root directory for stdio servers.
    worktree_store: Entity<WorktreeStore>,
    // Used to find the active project directory for stdio servers.
    project: WeakEntity<Project>,
    // Source of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // In-flight reconciliation task started by `available_context_servers_changed`, if any.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for how servers are constructed (see `ContextServerFactory`).
    context_server_factory: Option<ContextServerFactory>,
    // Set when a reconciliation is requested while one is already running, so another pass
    // is scheduled once the current one finishes.
    needs_server_update: bool,
    // Held only to keep the settings/registry observers alive.
    _subscriptions: Vec<Subscription>,
}

/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// The server `server_id` moved to `status`.
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

impl EventEmitter<Event> for ContextServerStore {}
 185
 186impl ContextServerStore {
 187    pub fn new(
 188        worktree_store: Entity<WorktreeStore>,
 189        weak_project: WeakEntity<Project>,
 190        cx: &mut Context<Self>,
 191    ) -> Self {
 192        Self::new_internal(
 193            true,
 194            None,
 195            ContextServerDescriptorRegistry::default_global(cx),
 196            worktree_store,
 197            weak_project,
 198            cx,
 199        )
 200    }
 201
 202    /// Returns all configured context server ids, excluding the ones that are disabled
 203    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 204        self.context_server_settings
 205            .iter()
 206            .filter(|(_, settings)| settings.enabled())
 207            .map(|(id, _)| ContextServerId(id.clone()))
 208            .collect()
 209    }
 210
 211    #[cfg(any(test, feature = "test-support"))]
 212    pub fn test(
 213        registry: Entity<ContextServerDescriptorRegistry>,
 214        worktree_store: Entity<WorktreeStore>,
 215        weak_project: WeakEntity<Project>,
 216        cx: &mut Context<Self>,
 217    ) -> Self {
 218        Self::new_internal(false, None, registry, worktree_store, weak_project, cx)
 219    }
 220
 221    #[cfg(any(test, feature = "test-support"))]
 222    pub fn test_maintain_server_loop(
 223        context_server_factory: Option<ContextServerFactory>,
 224        registry: Entity<ContextServerDescriptorRegistry>,
 225        worktree_store: Entity<WorktreeStore>,
 226        weak_project: WeakEntity<Project>,
 227        cx: &mut Context<Self>,
 228    ) -> Self {
 229        Self::new_internal(
 230            true,
 231            context_server_factory,
 232            registry,
 233            worktree_store,
 234            weak_project,
 235            cx,
 236        )
 237    }
 238
 239    fn new_internal(
 240        maintain_server_loop: bool,
 241        context_server_factory: Option<ContextServerFactory>,
 242        registry: Entity<ContextServerDescriptorRegistry>,
 243        worktree_store: Entity<WorktreeStore>,
 244        weak_project: WeakEntity<Project>,
 245        cx: &mut Context<Self>,
 246    ) -> Self {
 247        let subscriptions = if maintain_server_loop {
 248            vec![
 249                cx.observe(&registry, |this, _registry, cx| {
 250                    this.available_context_servers_changed(cx);
 251                }),
 252                cx.observe_global::<SettingsStore>(|this, cx| {
 253                    let settings = Self::resolve_context_server_settings(&this.worktree_store, cx);
 254                    if &this.context_server_settings == settings {
 255                        return;
 256                    }
 257                    this.context_server_settings = settings.clone();
 258                    this.available_context_servers_changed(cx);
 259                }),
 260            ]
 261        } else {
 262            Vec::new()
 263        };
 264
 265        let mut this = Self {
 266            _subscriptions: subscriptions,
 267            context_server_settings: Self::resolve_context_server_settings(&worktree_store, cx)
 268                .clone(),
 269            worktree_store,
 270            project: weak_project,
 271            registry,
 272            needs_server_update: false,
 273            servers: HashMap::default(),
 274            update_servers_task: None,
 275            context_server_factory,
 276        };
 277        if maintain_server_loop {
 278            this.available_context_servers_changed(cx);
 279        }
 280        this
 281    }
 282
 283    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 284        self.servers.get(id).map(|state| state.server())
 285    }
 286
 287    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 288        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 289            Some(server.clone())
 290        } else {
 291            None
 292        }
 293    }
 294
 295    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 296        self.servers.get(id).map(ContextServerStatus::from_state)
 297    }
 298
 299    pub fn configuration_for_server(
 300        &self,
 301        id: &ContextServerId,
 302    ) -> Option<Arc<ContextServerConfiguration>> {
 303        self.servers.get(id).map(|state| state.configuration())
 304    }
 305
 306    pub fn server_ids(&self, cx: &App) -> HashSet<ContextServerId> {
 307        self.servers
 308            .keys()
 309            .cloned()
 310            .chain(
 311                self.registry
 312                    .read(cx)
 313                    .context_server_descriptors()
 314                    .into_iter()
 315                    .map(|(id, _)| ContextServerId(id)),
 316            )
 317            .collect()
 318    }
 319
 320    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 321        self.servers
 322            .values()
 323            .filter_map(|state| {
 324                if let ContextServerState::Running { server, .. } = state {
 325                    Some(server.clone())
 326                } else {
 327                    None
 328                }
 329            })
 330            .collect()
 331    }
 332
 333    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 334        cx.spawn(async move |this, cx| {
 335            let this = this.upgrade().context("Context server store dropped")?;
 336            let settings = this
 337                .update(cx, |this, _| {
 338                    this.context_server_settings.get(&server.id().0).cloned()
 339                })
 340                .ok()
 341                .flatten()
 342                .context("Failed to get context server settings")?;
 343
 344            if !settings.enabled() {
 345                return Ok(());
 346            }
 347
 348            let (registry, worktree_store) = this.update(cx, |this, _| {
 349                (this.registry.clone(), this.worktree_store.clone())
 350            })?;
 351            let configuration = ContextServerConfiguration::from_settings(
 352                settings,
 353                server.id(),
 354                registry,
 355                worktree_store,
 356                cx,
 357            )
 358            .await
 359            .context("Failed to create context server configuration")?;
 360
 361            this.update(cx, |this, cx| {
 362                this.run_server(server, Arc::new(configuration), cx)
 363            })
 364        })
 365        .detach_and_log_err(cx);
 366    }
 367
 368    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 369        if matches!(
 370            self.servers.get(id),
 371            Some(ContextServerState::Stopped { .. })
 372        ) {
 373            return Ok(());
 374        }
 375
 376        let state = self
 377            .servers
 378            .remove(id)
 379            .context("Context server not found")?;
 380
 381        let server = state.server();
 382        let configuration = state.configuration();
 383        let mut result = Ok(());
 384        if let ContextServerState::Running { server, .. } = &state {
 385            result = server.stop();
 386        }
 387        drop(state);
 388
 389        self.update_server_state(
 390            id.clone(),
 391            ContextServerState::Stopped {
 392                configuration,
 393                server,
 394            },
 395            cx,
 396        );
 397
 398        result
 399    }
 400
 401    fn run_server(
 402        &mut self,
 403        server: Arc<ContextServer>,
 404        configuration: Arc<ContextServerConfiguration>,
 405        cx: &mut Context<Self>,
 406    ) {
 407        let id = server.id();
 408        if matches!(
 409            self.servers.get(&id),
 410            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 411        ) {
 412            self.stop_server(&id, cx).log_err();
 413        }
 414
 415        let task = cx.spawn({
 416            let id = server.id();
 417            let server = server.clone();
 418            let configuration = configuration.clone();
 419            async move |this, cx| {
 420                match server.clone().start(cx).await {
 421                    Ok(_) => {
 422                        debug_assert!(server.client().is_some());
 423
 424                        this.update(cx, |this, cx| {
 425                            this.update_server_state(
 426                                id.clone(),
 427                                ContextServerState::Running {
 428                                    server,
 429                                    configuration,
 430                                },
 431                                cx,
 432                            )
 433                        })
 434                        .log_err()
 435                    }
 436                    Err(err) => {
 437                        log::error!("{} context server failed to start: {}", id, err);
 438                        this.update(cx, |this, cx| {
 439                            this.update_server_state(
 440                                id.clone(),
 441                                ContextServerState::Error {
 442                                    configuration,
 443                                    server,
 444                                    error: err.to_string().into(),
 445                                },
 446                                cx,
 447                            )
 448                        })
 449                        .log_err()
 450                    }
 451                };
 452            }
 453        });
 454
 455        self.update_server_state(
 456            id.clone(),
 457            ContextServerState::Starting {
 458                configuration,
 459                _task: task,
 460                server,
 461            },
 462            cx,
 463        );
 464    }
 465
 466    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 467        let state = self
 468            .servers
 469            .remove(id)
 470            .context("Context server not found")?;
 471        drop(state);
 472        cx.emit(Event::ServerStatusChanged {
 473            server_id: id.clone(),
 474            status: ContextServerStatus::Stopped,
 475        });
 476        Ok(())
 477    }
 478
 479    fn create_context_server(
 480        &self,
 481        id: ContextServerId,
 482        configuration: Arc<ContextServerConfiguration>,
 483        cx: &mut Context<Self>,
 484    ) -> Result<Arc<ContextServer>> {
 485        if let Some(factory) = self.context_server_factory.as_ref() {
 486            return Ok(factory(id, configuration));
 487        }
 488
 489        match configuration.as_ref() {
 490            ContextServerConfiguration::Http { url, headers } => Ok(Arc::new(ContextServer::http(
 491                id,
 492                url,
 493                headers.clone(),
 494                cx.http_client(),
 495                cx.background_executor().clone(),
 496            )?)),
 497            _ => {
 498                let root_path = self
 499                    .project
 500                    .read_with(cx, |project, cx| project.active_project_directory(cx))
 501                    .ok()
 502                    .flatten()
 503                    .or_else(|| {
 504                        self.worktree_store.read_with(cx, |store, cx| {
 505                            store.visible_worktrees(cx).fold(None, |acc, item| {
 506                                if acc.is_none() {
 507                                    item.read(cx).root_dir()
 508                                } else {
 509                                    acc
 510                                }
 511                            })
 512                        })
 513                    });
 514                Ok(Arc::new(ContextServer::stdio(
 515                    id,
 516                    configuration.command().unwrap().clone(),
 517                    root_path,
 518                )))
 519            }
 520        }
 521    }
 522
 523    fn resolve_context_server_settings<'a>(
 524        worktree_store: &'a Entity<WorktreeStore>,
 525        cx: &'a App,
 526    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 527        let location = worktree_store
 528            .read(cx)
 529            .visible_worktrees(cx)
 530            .next()
 531            .map(|worktree| settings::SettingsLocation {
 532                worktree_id: worktree.read(cx).id(),
 533                path: RelPath::empty(),
 534            });
 535        &ProjectSettings::get(location, cx).context_servers
 536    }
 537
 538    fn update_server_state(
 539        &mut self,
 540        id: ContextServerId,
 541        state: ContextServerState,
 542        cx: &mut Context<Self>,
 543    ) {
 544        let status = ContextServerStatus::from_state(&state);
 545        self.servers.insert(id.clone(), state);
 546        cx.emit(Event::ServerStatusChanged {
 547            server_id: id,
 548            status,
 549        });
 550    }
 551
 552    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 553        if self.update_servers_task.is_some() {
 554            self.needs_server_update = true;
 555        } else {
 556            self.needs_server_update = false;
 557            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 558                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 559                    log::error!("Error maintaining context servers: {}", err);
 560                }
 561
 562                this.update(cx, |this, cx| {
 563                    this.update_servers_task.take();
 564                    if this.needs_server_update {
 565                        this.available_context_servers_changed(cx);
 566                    }
 567                })?;
 568
 569                Ok(())
 570            }));
 571        }
 572    }
 573
 574    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 575        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 576            (
 577                this.context_server_settings.clone(),
 578                this.registry.clone(),
 579                this.worktree_store.clone(),
 580            )
 581        })?;
 582
 583        for (id, _) in
 584            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 585        {
 586            configured_servers
 587                .entry(id)
 588                .or_insert(ContextServerSettings::default_extension());
 589        }
 590
 591        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 592            configured_servers
 593                .into_iter()
 594                .partition(|(_, settings)| settings.enabled());
 595
 596        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 597            let id = ContextServerId(id);
 598            ContextServerConfiguration::from_settings(
 599                settings,
 600                id.clone(),
 601                registry.clone(),
 602                worktree_store.clone(),
 603                cx,
 604            )
 605            .map(|config| (id, config))
 606        }))
 607        .await
 608        .into_iter()
 609        .filter_map(|(id, config)| config.map(|config| (id, config)))
 610        .collect::<HashMap<_, _>>();
 611
 612        let mut servers_to_start = Vec::new();
 613        let mut servers_to_remove = HashSet::default();
 614        let mut servers_to_stop = HashSet::default();
 615
 616        this.update(cx, |this, cx| {
 617            for server_id in this.servers.keys() {
 618                // All servers that are not in desired_servers should be removed from the store.
 619                // This can happen if the user removed a server from the context server settings.
 620                if !configured_servers.contains_key(server_id) {
 621                    if disabled_servers.contains_key(&server_id.0) {
 622                        servers_to_stop.insert(server_id.clone());
 623                    } else {
 624                        servers_to_remove.insert(server_id.clone());
 625                    }
 626                }
 627            }
 628
 629            for (id, config) in configured_servers {
 630                let state = this.servers.get(&id);
 631                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 632                let existing_config = state.as_ref().map(|state| state.configuration());
 633                if existing_config.as_deref() != Some(&config) || is_stopped {
 634                    let config = Arc::new(config);
 635                    let server = this.create_context_server(id.clone(), config.clone(), cx)?;
 636                    servers_to_start.push((server, config));
 637                    if this.servers.contains_key(&id) {
 638                        servers_to_stop.insert(id);
 639                    }
 640                }
 641            }
 642
 643            anyhow::Ok(())
 644        })??;
 645
 646        this.update(cx, |this, cx| {
 647            for id in servers_to_stop {
 648                this.stop_server(&id, cx)?;
 649            }
 650            for id in servers_to_remove {
 651                this.remove_server(&id, cx)?;
 652            }
 653            for (server, config) in servers_to_start {
 654                this.run_server(server, config, cx);
 655            }
 656            anyhow::Ok(())
 657        })?
 658    }
 659}
 660
 661#[cfg(test)]
 662mod tests {
 663    use super::*;
 664    use crate::{
 665        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 666        project_settings::ProjectSettings,
 667    };
 668    use context_server::test::create_fake_transport;
 669    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 670    use http_client::{FakeHttpClient, Response};
 671    use serde_json::json;
 672    use std::{cell::RefCell, path::PathBuf, rc::Rc};
 673    use util::path;
 674
 675    #[gpui::test]
 676    async fn test_context_server_status(cx: &mut TestAppContext) {
 677        const SERVER_1_ID: &str = "mcp-1";
 678        const SERVER_2_ID: &str = "mcp-2";
 679
 680        let (_fs, project) = setup_context_server_test(
 681            cx,
 682            json!({"code.rs": ""}),
 683            vec![
 684                (SERVER_1_ID.into(), dummy_server_settings()),
 685                (SERVER_2_ID.into(), dummy_server_settings()),
 686            ],
 687        )
 688        .await;
 689
 690        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 691        let store = cx.new(|cx| {
 692            ContextServerStore::test(
 693                registry.clone(),
 694                project.read(cx).worktree_store(),
 695                project.downgrade(),
 696                cx,
 697            )
 698        });
 699
 700        let server_1_id = ContextServerId(SERVER_1_ID.into());
 701        let server_2_id = ContextServerId(SERVER_2_ID.into());
 702
 703        let server_1 = Arc::new(ContextServer::new(
 704            server_1_id.clone(),
 705            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 706        ));
 707        let server_2 = Arc::new(ContextServer::new(
 708            server_2_id.clone(),
 709            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 710        ));
 711
 712        store.update(cx, |store, cx| store.start_server(server_1, cx));
 713
 714        cx.run_until_parked();
 715
 716        cx.update(|cx| {
 717            assert_eq!(
 718                store.read(cx).status_for_server(&server_1_id),
 719                Some(ContextServerStatus::Running)
 720            );
 721            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 722        });
 723
 724        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 725
 726        cx.run_until_parked();
 727
 728        cx.update(|cx| {
 729            assert_eq!(
 730                store.read(cx).status_for_server(&server_1_id),
 731                Some(ContextServerStatus::Running)
 732            );
 733            assert_eq!(
 734                store.read(cx).status_for_server(&server_2_id),
 735                Some(ContextServerStatus::Running)
 736            );
 737        });
 738
 739        store
 740            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 741            .unwrap();
 742
 743        cx.update(|cx| {
 744            assert_eq!(
 745                store.read(cx).status_for_server(&server_1_id),
 746                Some(ContextServerStatus::Running)
 747            );
 748            assert_eq!(
 749                store.read(cx).status_for_server(&server_2_id),
 750                Some(ContextServerStatus::Stopped)
 751            );
 752        });
 753    }
 754
 755    #[gpui::test]
 756    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 757        const SERVER_1_ID: &str = "mcp-1";
 758        const SERVER_2_ID: &str = "mcp-2";
 759
 760        let (_fs, project) = setup_context_server_test(
 761            cx,
 762            json!({"code.rs": ""}),
 763            vec![
 764                (SERVER_1_ID.into(), dummy_server_settings()),
 765                (SERVER_2_ID.into(), dummy_server_settings()),
 766            ],
 767        )
 768        .await;
 769
 770        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 771        let store = cx.new(|cx| {
 772            ContextServerStore::test(
 773                registry.clone(),
 774                project.read(cx).worktree_store(),
 775                project.downgrade(),
 776                cx,
 777            )
 778        });
 779
 780        let server_1_id = ContextServerId(SERVER_1_ID.into());
 781        let server_2_id = ContextServerId(SERVER_2_ID.into());
 782
 783        let server_1 = Arc::new(ContextServer::new(
 784            server_1_id.clone(),
 785            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 786        ));
 787        let server_2 = Arc::new(ContextServer::new(
 788            server_2_id.clone(),
 789            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 790        ));
 791
 792        let _server_events = assert_server_events(
 793            &store,
 794            vec![
 795                (server_1_id.clone(), ContextServerStatus::Starting),
 796                (server_1_id, ContextServerStatus::Running),
 797                (server_2_id.clone(), ContextServerStatus::Starting),
 798                (server_2_id.clone(), ContextServerStatus::Running),
 799                (server_2_id.clone(), ContextServerStatus::Stopped),
 800            ],
 801            cx,
 802        );
 803
 804        store.update(cx, |store, cx| store.start_server(server_1, cx));
 805
 806        cx.run_until_parked();
 807
 808        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 809
 810        cx.run_until_parked();
 811
 812        store
 813            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 814            .unwrap();
 815    }
 816
 817    #[gpui::test(iterations = 25)]
 818    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 819        const SERVER_1_ID: &str = "mcp-1";
 820
 821        let (_fs, project) = setup_context_server_test(
 822            cx,
 823            json!({"code.rs": ""}),
 824            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 825        )
 826        .await;
 827
 828        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 829        let store = cx.new(|cx| {
 830            ContextServerStore::test(
 831                registry.clone(),
 832                project.read(cx).worktree_store(),
 833                project.downgrade(),
 834                cx,
 835            )
 836        });
 837
 838        let server_id = ContextServerId(SERVER_1_ID.into());
 839
 840        let server_with_same_id_1 = Arc::new(ContextServer::new(
 841            server_id.clone(),
 842            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 843        ));
 844        let server_with_same_id_2 = Arc::new(ContextServer::new(
 845            server_id.clone(),
 846            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 847        ));
 848
 849        // If we start another server with the same id, we should report that we stopped the previous one
 850        let _server_events = assert_server_events(
 851            &store,
 852            vec![
 853                (server_id.clone(), ContextServerStatus::Starting),
 854                (server_id.clone(), ContextServerStatus::Stopped),
 855                (server_id.clone(), ContextServerStatus::Starting),
 856                (server_id.clone(), ContextServerStatus::Running),
 857            ],
 858            cx,
 859        );
 860
 861        store.update(cx, |store, cx| {
 862            store.start_server(server_with_same_id_1.clone(), cx)
 863        });
 864        store.update(cx, |store, cx| {
 865            store.start_server(server_with_same_id_2.clone(), cx)
 866        });
 867
 868        cx.run_until_parked();
 869
 870        cx.update(|cx| {
 871            assert_eq!(
 872                store.read(cx).status_for_server(&server_id),
 873                Some(ContextServerStatus::Running)
 874            );
 875        });
 876    }
 877
 878    #[gpui::test]
 879    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 880        const SERVER_1_ID: &str = "mcp-1";
 881        const SERVER_2_ID: &str = "mcp-2";
 882
 883        let server_1_id = ContextServerId(SERVER_1_ID.into());
 884        let server_2_id = ContextServerId(SERVER_2_ID.into());
 885
 886        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 887
 888        let (_fs, project) = setup_context_server_test(
 889            cx,
 890            json!({"code.rs": ""}),
 891            vec![(
 892                SERVER_1_ID.into(),
 893                ContextServerSettings::Extension {
 894                    enabled: true,
 895                    settings: json!({
 896                        "somevalue": true
 897                    }),
 898                },
 899            )],
 900        )
 901        .await;
 902
 903        let executor = cx.executor();
 904        let registry = cx.new(|cx| {
 905            let mut registry = ContextServerDescriptorRegistry::new();
 906            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1, cx);
 907            registry
 908        });
 909        let store = cx.new(|cx| {
 910            ContextServerStore::test_maintain_server_loop(
 911                Some(Box::new(move |id, _| {
 912                    Arc::new(ContextServer::new(
 913                        id.clone(),
 914                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 915                    ))
 916                })),
 917                registry.clone(),
 918                project.read(cx).worktree_store(),
 919                project.downgrade(),
 920                cx,
 921            )
 922        });
 923
 924        // Ensure that mcp-1 starts up
 925        {
 926            let _server_events = assert_server_events(
 927                &store,
 928                vec![
 929                    (server_1_id.clone(), ContextServerStatus::Starting),
 930                    (server_1_id.clone(), ContextServerStatus::Running),
 931                ],
 932                cx,
 933            );
 934            cx.run_until_parked();
 935        }
 936
 937        // Ensure that mcp-1 is restarted when the configuration was changed
 938        {
 939            let _server_events = assert_server_events(
 940                &store,
 941                vec![
 942                    (server_1_id.clone(), ContextServerStatus::Stopped),
 943                    (server_1_id.clone(), ContextServerStatus::Starting),
 944                    (server_1_id.clone(), ContextServerStatus::Running),
 945                ],
 946                cx,
 947            );
 948            set_context_server_configuration(
 949                vec![(
 950                    server_1_id.0.clone(),
 951                    settings::ContextServerSettingsContent::Extension {
 952                        enabled: true,
 953                        settings: json!({
 954                            "somevalue": false
 955                        }),
 956                    },
 957                )],
 958                cx,
 959            );
 960
 961            cx.run_until_parked();
 962        }
 963
 964        // Ensure that mcp-1 is not restarted when the configuration was not changed
 965        {
 966            let _server_events = assert_server_events(&store, vec![], cx);
 967            set_context_server_configuration(
 968                vec![(
 969                    server_1_id.0.clone(),
 970                    settings::ContextServerSettingsContent::Extension {
 971                        enabled: true,
 972                        settings: json!({
 973                            "somevalue": false
 974                        }),
 975                    },
 976                )],
 977                cx,
 978            );
 979
 980            cx.run_until_parked();
 981        }
 982
 983        // Ensure that mcp-2 is started once it is added to the settings
 984        {
 985            let _server_events = assert_server_events(
 986                &store,
 987                vec![
 988                    (server_2_id.clone(), ContextServerStatus::Starting),
 989                    (server_2_id.clone(), ContextServerStatus::Running),
 990                ],
 991                cx,
 992            );
 993            set_context_server_configuration(
 994                vec![
 995                    (
 996                        server_1_id.0.clone(),
 997                        settings::ContextServerSettingsContent::Extension {
 998                            enabled: true,
 999                            settings: json!({
1000                                "somevalue": false
1001                            }),
1002                        },
1003                    ),
1004                    (
1005                        server_2_id.0.clone(),
1006                        settings::ContextServerSettingsContent::Stdio {
1007                            enabled: true,
1008                            command: ContextServerCommand {
1009                                path: "somebinary".into(),
1010                                args: vec!["arg".to_string()],
1011                                env: None,
1012                                timeout: None,
1013                            },
1014                        },
1015                    ),
1016                ],
1017                cx,
1018            );
1019
1020            cx.run_until_parked();
1021        }
1022
1023        // Ensure that mcp-2 is restarted once the args have changed
1024        {
1025            let _server_events = assert_server_events(
1026                &store,
1027                vec![
1028                    (server_2_id.clone(), ContextServerStatus::Stopped),
1029                    (server_2_id.clone(), ContextServerStatus::Starting),
1030                    (server_2_id.clone(), ContextServerStatus::Running),
1031                ],
1032                cx,
1033            );
1034            set_context_server_configuration(
1035                vec![
1036                    (
1037                        server_1_id.0.clone(),
1038                        settings::ContextServerSettingsContent::Extension {
1039                            enabled: true,
1040                            settings: json!({
1041                                "somevalue": false
1042                            }),
1043                        },
1044                    ),
1045                    (
1046                        server_2_id.0.clone(),
1047                        settings::ContextServerSettingsContent::Stdio {
1048                            enabled: true,
1049                            command: ContextServerCommand {
1050                                path: "somebinary".into(),
1051                                args: vec!["anotherArg".to_string()],
1052                                env: None,
1053                                timeout: None,
1054                            },
1055                        },
1056                    ),
1057                ],
1058                cx,
1059            );
1060
1061            cx.run_until_parked();
1062        }
1063
1064        // Ensure that mcp-2 is removed once it is removed from the settings
1065        {
1066            let _server_events = assert_server_events(
1067                &store,
1068                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
1069                cx,
1070            );
1071            set_context_server_configuration(
1072                vec![(
1073                    server_1_id.0.clone(),
1074                    settings::ContextServerSettingsContent::Extension {
1075                        enabled: true,
1076                        settings: json!({
1077                            "somevalue": false
1078                        }),
1079                    },
1080                )],
1081                cx,
1082            );
1083
1084            cx.run_until_parked();
1085
1086            cx.update(|cx| {
1087                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1088            });
1089        }
1090
1091        // Ensure that nothing happens if the settings do not change
1092        {
1093            let _server_events = assert_server_events(&store, vec![], cx);
1094            set_context_server_configuration(
1095                vec![(
1096                    server_1_id.0.clone(),
1097                    settings::ContextServerSettingsContent::Extension {
1098                        enabled: true,
1099                        settings: json!({
1100                            "somevalue": false
1101                        }),
1102                    },
1103                )],
1104                cx,
1105            );
1106
1107            cx.run_until_parked();
1108
1109            cx.update(|cx| {
1110                assert_eq!(
1111                    store.read(cx).status_for_server(&server_1_id),
1112                    Some(ContextServerStatus::Running)
1113                );
1114                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
1115            });
1116        }
1117    }
1118
1119    #[gpui::test]
1120    async fn test_context_server_enabled_disabled(cx: &mut TestAppContext) {
1121        const SERVER_1_ID: &str = "mcp-1";
1122
1123        let server_1_id = ContextServerId(SERVER_1_ID.into());
1124
1125        let (_fs, project) = setup_context_server_test(
1126            cx,
1127            json!({"code.rs": ""}),
1128            vec![(
1129                SERVER_1_ID.into(),
1130                ContextServerSettings::Stdio {
1131                    enabled: true,
1132                    command: ContextServerCommand {
1133                        path: "somebinary".into(),
1134                        args: vec!["arg".to_string()],
1135                        env: None,
1136                        timeout: None,
1137                    },
1138                },
1139            )],
1140        )
1141        .await;
1142
1143        let executor = cx.executor();
1144        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1145        let store = cx.new(|cx| {
1146            ContextServerStore::test_maintain_server_loop(
1147                Some(Box::new(move |id, _| {
1148                    Arc::new(ContextServer::new(
1149                        id.clone(),
1150                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
1151                    ))
1152                })),
1153                registry.clone(),
1154                project.read(cx).worktree_store(),
1155                project.downgrade(),
1156                cx,
1157            )
1158        });
1159
1160        // Ensure that mcp-1 starts up
1161        {
1162            let _server_events = assert_server_events(
1163                &store,
1164                vec![
1165                    (server_1_id.clone(), ContextServerStatus::Starting),
1166                    (server_1_id.clone(), ContextServerStatus::Running),
1167                ],
1168                cx,
1169            );
1170            cx.run_until_parked();
1171        }
1172
1173        // Ensure that mcp-1 is stopped once it is disabled.
1174        {
1175            let _server_events = assert_server_events(
1176                &store,
1177                vec![(server_1_id.clone(), ContextServerStatus::Stopped)],
1178                cx,
1179            );
1180            set_context_server_configuration(
1181                vec![(
1182                    server_1_id.0.clone(),
1183                    settings::ContextServerSettingsContent::Stdio {
1184                        enabled: false,
1185                        command: ContextServerCommand {
1186                            path: "somebinary".into(),
1187                            args: vec!["arg".to_string()],
1188                            env: None,
1189                            timeout: None,
1190                        },
1191                    },
1192                )],
1193                cx,
1194            );
1195
1196            cx.run_until_parked();
1197        }
1198
1199        // Ensure that mcp-1 is started once it is enabled again.
1200        {
1201            let _server_events = assert_server_events(
1202                &store,
1203                vec![
1204                    (server_1_id.clone(), ContextServerStatus::Starting),
1205                    (server_1_id.clone(), ContextServerStatus::Running),
1206                ],
1207                cx,
1208            );
1209            set_context_server_configuration(
1210                vec![(
1211                    server_1_id.0.clone(),
1212                    settings::ContextServerSettingsContent::Stdio {
1213                        enabled: true,
1214                        command: ContextServerCommand {
1215                            path: "somebinary".into(),
1216                            args: vec!["arg".to_string()],
1217                            timeout: None,
1218                            env: None,
1219                        },
1220                    },
1221                )],
1222                cx,
1223            );
1224
1225            cx.run_until_parked();
1226        }
1227    }
1228
1229    fn set_context_server_configuration(
1230        context_servers: Vec<(Arc<str>, settings::ContextServerSettingsContent)>,
1231        cx: &mut TestAppContext,
1232    ) {
1233        cx.update(|cx| {
1234            SettingsStore::update_global(cx, |store, cx| {
1235                store.update_user_settings(cx, |content| {
1236                    content.project.context_servers.clear();
1237                    for (id, config) in context_servers {
1238                        content.project.context_servers.insert(id, config);
1239                    }
1240                });
1241            })
1242        });
1243    }
1244
1245    #[gpui::test]
1246    async fn test_remote_context_server(cx: &mut TestAppContext) {
1247        const SERVER_ID: &str = "remote-server";
1248        let server_id = ContextServerId(SERVER_ID.into());
1249        let server_url = "http://example.com/api";
1250
1251        let (_fs, project) = setup_context_server_test(
1252            cx,
1253            json!({ "code.rs": "" }),
1254            vec![(
1255                SERVER_ID.into(),
1256                ContextServerSettings::Http {
1257                    enabled: true,
1258                    url: server_url.to_string(),
1259                    headers: Default::default(),
1260                },
1261            )],
1262        )
1263        .await;
1264
1265        let client = FakeHttpClient::create(|_| async move {
1266            use http_client::AsyncBody;
1267
1268            let response = Response::builder()
1269                .status(200)
1270                .header("Content-Type", "application/json")
1271                .body(AsyncBody::from(
1272                    serde_json::to_string(&json!({
1273                        "jsonrpc": "2.0",
1274                        "id": 0,
1275                        "result": {
1276                            "protocolVersion": "2024-11-05",
1277                            "capabilities": {},
1278                            "serverInfo": {
1279                                "name": "test-server",
1280                                "version": "1.0.0"
1281                            }
1282                        }
1283                    }))
1284                    .unwrap(),
1285                ))
1286                .unwrap();
1287            Ok(response)
1288        });
1289        cx.update(|cx| cx.set_http_client(client));
1290        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
1291        let store = cx.new(|cx| {
1292            ContextServerStore::test_maintain_server_loop(
1293                None,
1294                registry.clone(),
1295                project.read(cx).worktree_store(),
1296                project.downgrade(),
1297                cx,
1298            )
1299        });
1300
1301        let _server_events = assert_server_events(
1302            &store,
1303            vec![
1304                (server_id.clone(), ContextServerStatus::Starting),
1305                (server_id.clone(), ContextServerStatus::Running),
1306            ],
1307            cx,
1308        );
1309        cx.run_until_parked();
1310    }
1311
    /// Guard returned by `assert_server_events`. It holds the store subscription and, when
    /// dropped, compares the number of received status events with the expected count.
    struct ServerEvents {
        /// Number of events seen so far; shared with (and incremented by) the subscription callback.
        received_event_count: Rc<RefCell<usize>>,
        /// Number of events the test expects to have been received by the time this guard drops.
        expected_event_count: usize,
        /// Held only so the subscription lives as long as the guard
        /// (presumably gpui unsubscribes when a `Subscription` is dropped — confirm).
        _subscription: Subscription,
    }
1317
1318    impl Drop for ServerEvents {
1319        fn drop(&mut self) {
1320            let actual_event_count = *self.received_event_count.borrow();
1321            assert_eq!(
1322                actual_event_count, self.expected_event_count,
1323                "
1324                Expected to receive {} context server store events, but received {} events",
1325                self.expected_event_count, actual_event_count
1326            );
1327        }
1328    }
1329
1330    fn dummy_server_settings() -> ContextServerSettings {
1331        ContextServerSettings::Stdio {
1332            enabled: true,
1333            command: ContextServerCommand {
1334                path: "somebinary".into(),
1335                args: vec!["arg".to_string()],
1336                env: None,
1337                timeout: None,
1338            },
1339        }
1340    }
1341
1342    fn assert_server_events(
1343        store: &Entity<ContextServerStore>,
1344        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1345        cx: &mut TestAppContext,
1346    ) -> ServerEvents {
1347        cx.update(|cx| {
1348            let mut ix = 0;
1349            let received_event_count = Rc::new(RefCell::new(0));
1350            let expected_event_count = expected_events.len();
1351            let subscription = cx.subscribe(store, {
1352                let received_event_count = received_event_count.clone();
1353                move |_, event, _| match event {
1354                    Event::ServerStatusChanged {
1355                        server_id: actual_server_id,
1356                        status: actual_status,
1357                    } => {
1358                        let (expected_server_id, expected_status) = &expected_events[ix];
1359
1360                        assert_eq!(
1361                            actual_server_id, expected_server_id,
1362                            "Expected different server id at index {}",
1363                            ix
1364                        );
1365                        assert_eq!(
1366                            actual_status, expected_status,
1367                            "Expected different status at index {}",
1368                            ix
1369                        );
1370                        ix += 1;
1371                        *received_event_count.borrow_mut() += 1;
1372                    }
1373                }
1374            });
1375            ServerEvents {
1376                expected_event_count,
1377                received_event_count,
1378                _subscription: subscription,
1379            }
1380        })
1381    }
1382
1383    async fn setup_context_server_test(
1384        cx: &mut TestAppContext,
1385        files: serde_json::Value,
1386        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1387    ) -> (Arc<FakeFs>, Entity<Project>) {
1388        cx.update(|cx| {
1389            let settings_store = SettingsStore::test(cx);
1390            cx.set_global(settings_store);
1391            let mut settings = ProjectSettings::get_global(cx).clone();
1392            for (id, config) in context_server_configurations {
1393                settings.context_servers.insert(id, config);
1394            }
1395            ProjectSettings::override_global(settings, cx);
1396        });
1397
1398        let fs = FakeFs::new(cx.executor());
1399        fs.insert_tree(path!("/test"), files).await;
1400        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1401
1402        (fs, project)
1403    }
1404
    /// Test descriptor. Its `command` always yields `path` with the fixed arguments
    /// `arg1 arg2`, and its `configuration` always yields `None`.
    struct FakeContextServerDescriptor {
        /// Returned as the command path by `ContextServerDescriptor::command`.
        path: PathBuf,
    }
1408
1409    impl FakeContextServerDescriptor {
1410        fn new(path: impl Into<PathBuf>) -> Self {
1411            Self { path: path.into() }
1412        }
1413    }
1414
1415    impl ContextServerDescriptor for FakeContextServerDescriptor {
1416        fn command(
1417            &self,
1418            _worktree_store: Entity<WorktreeStore>,
1419            _cx: &AsyncApp,
1420        ) -> Task<Result<ContextServerCommand>> {
1421            Task::ready(Ok(ContextServerCommand {
1422                path: self.path.clone(),
1423                args: vec!["arg1".to_string(), "arg2".to_string()],
1424                env: None,
1425                timeout: None,
1426            }))
1427        }
1428
1429        fn configuration(
1430            &self,
1431            _worktree_store: Entity<WorktreeStore>,
1432            _cx: &AsyncApp,
1433        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1434            Task::ready(Ok(None))
1435        }
1436    }
1437}