context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    project_settings::{ContextServerSettings, ProjectSettings},
  17    worktree_store::WorktreeStore,
  18};
  19
  20pub fn init(cx: &mut App) {
  21    extension::init(cx);
  22}
  23
  24actions!(context_server, [Restart]);
  25
  26#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  27pub enum ContextServerStatus {
  28    Starting,
  29    Running,
  30    Stopped,
  31    Error(Arc<str>),
  32}
  33
  34impl ContextServerStatus {
  35    fn from_state(state: &ContextServerState) -> Self {
  36        match state {
  37            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  38            ContextServerState::Running { .. } => ContextServerStatus::Running,
  39            ContextServerState::Stopped { error, .. } => {
  40                if let Some(error) = error {
  41                    ContextServerStatus::Error(error.clone())
  42                } else {
  43                    ContextServerStatus::Stopped
  44                }
  45            }
  46        }
  47    }
  48}
  49
  50enum ContextServerState {
  51    Starting {
  52        server: Arc<ContextServer>,
  53        configuration: Arc<ContextServerConfiguration>,
  54        _task: Task<()>,
  55    },
  56    Running {
  57        server: Arc<ContextServer>,
  58        configuration: Arc<ContextServerConfiguration>,
  59    },
  60    Stopped {
  61        server: Arc<ContextServer>,
  62        configuration: Arc<ContextServerConfiguration>,
  63        error: Option<Arc<str>>,
  64    },
  65}
  66
  67impl ContextServerState {
  68    pub fn server(&self) -> Arc<ContextServer> {
  69        match self {
  70            ContextServerState::Starting { server, .. } => server.clone(),
  71            ContextServerState::Running { server, .. } => server.clone(),
  72            ContextServerState::Stopped { server, .. } => server.clone(),
  73        }
  74    }
  75
  76    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  77        match self {
  78            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  79            ContextServerState::Running { configuration, .. } => configuration.clone(),
  80            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  81        }
  82    }
  83}
  84
  85#[derive(PartialEq, Eq)]
  86pub enum ContextServerConfiguration {
  87    Custom {
  88        command: ContextServerCommand,
  89    },
  90    Extension {
  91        command: ContextServerCommand,
  92        settings: serde_json::Value,
  93    },
  94}
  95
  96impl ContextServerConfiguration {
  97    pub fn command(&self) -> &ContextServerCommand {
  98        match self {
  99            ContextServerConfiguration::Custom { command } => command,
 100            ContextServerConfiguration::Extension { command, .. } => command,
 101        }
 102    }
 103
 104    pub async fn from_settings(
 105        settings: ContextServerSettings,
 106        id: ContextServerId,
 107        registry: Entity<ContextServerDescriptorRegistry>,
 108        worktree_store: Entity<WorktreeStore>,
 109        cx: &AsyncApp,
 110    ) -> Option<Self> {
 111        match settings {
 112            ContextServerSettings::Custom { command } => {
 113                Some(ContextServerConfiguration::Custom { command })
 114            }
 115            ContextServerSettings::Extension { settings } => {
 116                let descriptor = cx
 117                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 118                    .ok()
 119                    .flatten()?;
 120
 121                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 122
 123                Some(ContextServerConfiguration::Extension { command, settings })
 124            }
 125        }
 126    }
 127}
 128
 129pub type ContextServerFactory =
 130    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 131
 132pub struct ContextServerStore {
 133    servers: HashMap<ContextServerId, ContextServerState>,
 134    worktree_store: Entity<WorktreeStore>,
 135    registry: Entity<ContextServerDescriptorRegistry>,
 136    update_servers_task: Option<Task<Result<()>>>,
 137    context_server_factory: Option<ContextServerFactory>,
 138    needs_server_update: bool,
 139    _subscriptions: Vec<Subscription>,
 140}
 141
 142pub enum Event {
 143    ServerStatusChanged {
 144        server_id: ContextServerId,
 145        status: ContextServerStatus,
 146    },
 147}
 148
 149impl EventEmitter<Event> for ContextServerStore {}
 150
 151impl ContextServerStore {
 152    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 153        Self::new_internal(
 154            true,
 155            None,
 156            ContextServerDescriptorRegistry::default_global(cx),
 157            worktree_store,
 158            cx,
 159        )
 160    }
 161
 162    #[cfg(any(test, feature = "test-support"))]
 163    pub fn test(
 164        registry: Entity<ContextServerDescriptorRegistry>,
 165        worktree_store: Entity<WorktreeStore>,
 166        cx: &mut Context<Self>,
 167    ) -> Self {
 168        Self::new_internal(false, None, registry, worktree_store, cx)
 169    }
 170
 171    #[cfg(any(test, feature = "test-support"))]
 172    pub fn test_maintain_server_loop(
 173        context_server_factory: ContextServerFactory,
 174        registry: Entity<ContextServerDescriptorRegistry>,
 175        worktree_store: Entity<WorktreeStore>,
 176        cx: &mut Context<Self>,
 177    ) -> Self {
 178        Self::new_internal(
 179            true,
 180            Some(context_server_factory),
 181            registry,
 182            worktree_store,
 183            cx,
 184        )
 185    }
 186
 187    fn new_internal(
 188        maintain_server_loop: bool,
 189        context_server_factory: Option<ContextServerFactory>,
 190        registry: Entity<ContextServerDescriptorRegistry>,
 191        worktree_store: Entity<WorktreeStore>,
 192        cx: &mut Context<Self>,
 193    ) -> Self {
 194        let subscriptions = if maintain_server_loop {
 195            vec![
 196                cx.observe(&registry, |this, _registry, cx| {
 197                    this.available_context_servers_changed(cx);
 198                }),
 199                cx.observe_global::<SettingsStore>(|this, cx| {
 200                    this.available_context_servers_changed(cx);
 201                }),
 202            ]
 203        } else {
 204            Vec::new()
 205        };
 206
 207        let mut this = Self {
 208            _subscriptions: subscriptions,
 209            worktree_store,
 210            registry,
 211            needs_server_update: false,
 212            servers: HashMap::default(),
 213            update_servers_task: None,
 214            context_server_factory,
 215        };
 216        if maintain_server_loop {
 217            this.available_context_servers_changed(cx);
 218        }
 219        this
 220    }
 221
 222    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 223        self.servers.get(id).map(|state| state.server())
 224    }
 225
 226    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 227        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 228            Some(server.clone())
 229        } else {
 230            None
 231        }
 232    }
 233
 234    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 235        self.servers.get(id).map(ContextServerStatus::from_state)
 236    }
 237
 238    pub fn configuration_for_server(
 239        &self,
 240        id: &ContextServerId,
 241    ) -> Option<Arc<ContextServerConfiguration>> {
 242        self.servers.get(id).map(|state| state.configuration())
 243    }
 244
 245    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 246        self.servers.keys().cloned().collect()
 247    }
 248
 249    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 250        self.servers
 251            .values()
 252            .filter_map(|state| {
 253                if let ContextServerState::Running { server, .. } = state {
 254                    Some(server.clone())
 255                } else {
 256                    None
 257                }
 258            })
 259            .collect()
 260    }
 261
 262    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 263        cx.spawn(async move |this, cx| {
 264            let this = this.upgrade().context("Context server store dropped")?;
 265            let settings = this
 266                .update(cx, |this, cx| {
 267                    this.context_server_settings(cx)
 268                        .get(&server.id().0)
 269                        .cloned()
 270                })
 271                .ok()
 272                .flatten()
 273                .context("Failed to get context server settings")?;
 274
 275            let (registry, worktree_store) = this.update(cx, |this, _| {
 276                (this.registry.clone(), this.worktree_store.clone())
 277            })?;
 278            let configuration = ContextServerConfiguration::from_settings(
 279                settings,
 280                server.id(),
 281                registry,
 282                worktree_store,
 283                cx,
 284            )
 285            .await
 286            .context("Failed to create context server configuration")?;
 287
 288            this.update(cx, |this, cx| {
 289                this.run_server(server, Arc::new(configuration), cx)
 290            })
 291        })
 292        .detach_and_log_err(cx);
 293    }
 294
 295    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 296        let state = self
 297            .servers
 298            .remove(id)
 299            .context("Context server not found")?;
 300
 301        let server = state.server();
 302        let configuration = state.configuration();
 303        let mut result = Ok(());
 304        if let ContextServerState::Running { server, .. } = &state {
 305            result = server.stop();
 306        }
 307        drop(state);
 308
 309        self.update_server_state(
 310            id.clone(),
 311            ContextServerState::Stopped {
 312                configuration,
 313                server,
 314                error: None,
 315            },
 316            cx,
 317        );
 318
 319        result
 320    }
 321
 322    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 323        if let Some(state) = self.servers.get(&id) {
 324            let configuration = state.configuration();
 325
 326            self.stop_server(&state.server().id(), cx)?;
 327            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 328            self.run_server(new_server, configuration, cx);
 329        }
 330        Ok(())
 331    }
 332
 333    fn run_server(
 334        &mut self,
 335        server: Arc<ContextServer>,
 336        configuration: Arc<ContextServerConfiguration>,
 337        cx: &mut Context<Self>,
 338    ) {
 339        let id = server.id();
 340        if matches!(
 341            self.servers.get(&id),
 342            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 343        ) {
 344            self.stop_server(&id, cx).log_err();
 345        }
 346
 347        let task = cx.spawn({
 348            let id = server.id();
 349            let server = server.clone();
 350            let configuration = configuration.clone();
 351            async move |this, cx| {
 352                match server.clone().start(&cx).await {
 353                    Ok(_) => {
 354                        log::info!("Started {} context server", id);
 355                        debug_assert!(server.client().is_some());
 356
 357                        this.update(cx, |this, cx| {
 358                            this.update_server_state(
 359                                id.clone(),
 360                                ContextServerState::Running {
 361                                    server,
 362                                    configuration,
 363                                },
 364                                cx,
 365                            )
 366                        })
 367                        .log_err()
 368                    }
 369                    Err(err) => {
 370                        log::error!("{} context server failed to start: {}", id, err);
 371                        this.update(cx, |this, cx| {
 372                            this.update_server_state(
 373                                id.clone(),
 374                                ContextServerState::Stopped {
 375                                    configuration,
 376                                    server,
 377                                    error: Some(err.to_string().into()),
 378                                },
 379                                cx,
 380                            )
 381                        })
 382                        .log_err()
 383                    }
 384                };
 385            }
 386        });
 387
 388        self.update_server_state(
 389            id.clone(),
 390            ContextServerState::Starting {
 391                configuration,
 392                _task: task,
 393                server,
 394            },
 395            cx,
 396        );
 397    }
 398
 399    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 400        let state = self
 401            .servers
 402            .remove(id)
 403            .context("Context server not found")?;
 404        drop(state);
 405        cx.emit(Event::ServerStatusChanged {
 406            server_id: id.clone(),
 407            status: ContextServerStatus::Stopped,
 408        });
 409        Ok(())
 410    }
 411
 412    fn create_context_server(
 413        &self,
 414        id: ContextServerId,
 415        configuration: Arc<ContextServerConfiguration>,
 416    ) -> Result<Arc<ContextServer>> {
 417        if let Some(factory) = self.context_server_factory.as_ref() {
 418            Ok(factory(id, configuration))
 419        } else {
 420            Ok(Arc::new(ContextServer::stdio(
 421                id,
 422                configuration.command().clone(),
 423            )))
 424        }
 425    }
 426
 427    fn context_server_settings<'a>(
 428        &'a self,
 429        cx: &'a App,
 430    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 431        let location = self
 432            .worktree_store
 433            .read(cx)
 434            .visible_worktrees(cx)
 435            .next()
 436            .map(|worktree| settings::SettingsLocation {
 437                worktree_id: worktree.read(cx).id(),
 438                path: Path::new(""),
 439            });
 440        &ProjectSettings::get(location, cx).context_servers
 441    }
 442
 443    fn update_server_state(
 444        &mut self,
 445        id: ContextServerId,
 446        state: ContextServerState,
 447        cx: &mut Context<Self>,
 448    ) {
 449        let status = ContextServerStatus::from_state(&state);
 450        self.servers.insert(id.clone(), state);
 451        cx.emit(Event::ServerStatusChanged {
 452            server_id: id,
 453            status,
 454        });
 455    }
 456
 457    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 458        if self.update_servers_task.is_some() {
 459            self.needs_server_update = true;
 460        } else {
 461            self.needs_server_update = false;
 462            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 463                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 464                    log::error!("Error maintaining context servers: {}", err);
 465                }
 466
 467                this.update(cx, |this, cx| {
 468                    this.update_servers_task.take();
 469                    if this.needs_server_update {
 470                        this.available_context_servers_changed(cx);
 471                    }
 472                })?;
 473
 474                Ok(())
 475            }));
 476        }
 477    }
 478
 479    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 480        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, cx| {
 481            (
 482                this.context_server_settings(cx).clone(),
 483                this.registry.clone(),
 484                this.worktree_store.clone(),
 485            )
 486        })?;
 487
 488        for (id, _) in
 489            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 490        {
 491            configured_servers
 492                .entry(id)
 493                .or_insert(ContextServerSettings::Extension {
 494                    settings: serde_json::json!({}),
 495                });
 496        }
 497
 498        let configured_servers = join_all(configured_servers.into_iter().map(|(id, settings)| {
 499            let id = ContextServerId(id);
 500            ContextServerConfiguration::from_settings(
 501                settings,
 502                id.clone(),
 503                registry.clone(),
 504                worktree_store.clone(),
 505                cx,
 506            )
 507            .map(|config| (id, config))
 508        }))
 509        .await
 510        .into_iter()
 511        .filter_map(|(id, config)| config.map(|config| (id, config)))
 512        .collect::<HashMap<_, _>>();
 513
 514        let mut servers_to_start = Vec::new();
 515        let mut servers_to_remove = HashSet::default();
 516        let mut servers_to_stop = HashSet::default();
 517
 518        this.update(cx, |this, _cx| {
 519            for server_id in this.servers.keys() {
 520                // All servers that are not in desired_servers should be removed from the store.
 521                // This can happen if the user removed a server from the context server settings.
 522                if !configured_servers.contains_key(&server_id) {
 523                    servers_to_remove.insert(server_id.clone());
 524                }
 525            }
 526
 527            for (id, config) in configured_servers {
 528                let existing_config = this.servers.get(&id).map(|state| state.configuration());
 529                if existing_config.as_deref() != Some(&config) {
 530                    let config = Arc::new(config);
 531                    if let Some(server) = this
 532                        .create_context_server(id.clone(), config.clone())
 533                        .log_err()
 534                    {
 535                        servers_to_start.push((server, config));
 536                        if this.servers.contains_key(&id) {
 537                            servers_to_stop.insert(id);
 538                        }
 539                    }
 540                }
 541            }
 542        })?;
 543
 544        this.update(cx, |this, cx| {
 545            for id in servers_to_stop {
 546                this.stop_server(&id, cx)?;
 547            }
 548            for id in servers_to_remove {
 549                this.remove_server(&id, cx)?;
 550            }
 551            for (server, config) in servers_to_start {
 552                this.run_server(server, config, cx);
 553            }
 554            anyhow::Ok(())
 555        })?
 556    }
 557}
 558
 559#[cfg(test)]
 560mod tests {
 561    use super::*;
 562    use crate::{
 563        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 564        project_settings::ProjectSettings,
 565    };
 566    use context_server::test::create_fake_transport;
 567    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 568    use serde_json::json;
 569    use std::{cell::RefCell, rc::Rc};
 570    use util::path;
 571
 572    #[gpui::test]
 573    async fn test_context_server_status(cx: &mut TestAppContext) {
 574        const SERVER_1_ID: &'static str = "mcp-1";
 575        const SERVER_2_ID: &'static str = "mcp-2";
 576
 577        let (_fs, project) = setup_context_server_test(
 578            cx,
 579            json!({"code.rs": ""}),
 580            vec![
 581                (SERVER_1_ID.into(), dummy_server_settings()),
 582                (SERVER_2_ID.into(), dummy_server_settings()),
 583            ],
 584        )
 585        .await;
 586
 587        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 588        let store = cx.new(|cx| {
 589            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 590        });
 591
 592        let server_1_id = ContextServerId(SERVER_1_ID.into());
 593        let server_2_id = ContextServerId(SERVER_2_ID.into());
 594
 595        let server_1 = Arc::new(ContextServer::new(
 596            server_1_id.clone(),
 597            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 598        ));
 599        let server_2 = Arc::new(ContextServer::new(
 600            server_2_id.clone(),
 601            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 602        ));
 603
 604        store.update(cx, |store, cx| store.start_server(server_1, cx));
 605
 606        cx.run_until_parked();
 607
 608        cx.update(|cx| {
 609            assert_eq!(
 610                store.read(cx).status_for_server(&server_1_id),
 611                Some(ContextServerStatus::Running)
 612            );
 613            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 614        });
 615
 616        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 617
 618        cx.run_until_parked();
 619
 620        cx.update(|cx| {
 621            assert_eq!(
 622                store.read(cx).status_for_server(&server_1_id),
 623                Some(ContextServerStatus::Running)
 624            );
 625            assert_eq!(
 626                store.read(cx).status_for_server(&server_2_id),
 627                Some(ContextServerStatus::Running)
 628            );
 629        });
 630
 631        store
 632            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 633            .unwrap();
 634
 635        cx.update(|cx| {
 636            assert_eq!(
 637                store.read(cx).status_for_server(&server_1_id),
 638                Some(ContextServerStatus::Running)
 639            );
 640            assert_eq!(
 641                store.read(cx).status_for_server(&server_2_id),
 642                Some(ContextServerStatus::Stopped)
 643            );
 644        });
 645    }
 646
 647    #[gpui::test]
 648    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 649        const SERVER_1_ID: &'static str = "mcp-1";
 650        const SERVER_2_ID: &'static str = "mcp-2";
 651
 652        let (_fs, project) = setup_context_server_test(
 653            cx,
 654            json!({"code.rs": ""}),
 655            vec![
 656                (SERVER_1_ID.into(), dummy_server_settings()),
 657                (SERVER_2_ID.into(), dummy_server_settings()),
 658            ],
 659        )
 660        .await;
 661
 662        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 663        let store = cx.new(|cx| {
 664            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 665        });
 666
 667        let server_1_id = ContextServerId(SERVER_1_ID.into());
 668        let server_2_id = ContextServerId(SERVER_2_ID.into());
 669
 670        let server_1 = Arc::new(ContextServer::new(
 671            server_1_id.clone(),
 672            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 673        ));
 674        let server_2 = Arc::new(ContextServer::new(
 675            server_2_id.clone(),
 676            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
 677        ));
 678
 679        let _server_events = assert_server_events(
 680            &store,
 681            vec![
 682                (server_1_id.clone(), ContextServerStatus::Starting),
 683                (server_1_id.clone(), ContextServerStatus::Running),
 684                (server_2_id.clone(), ContextServerStatus::Starting),
 685                (server_2_id.clone(), ContextServerStatus::Running),
 686                (server_2_id.clone(), ContextServerStatus::Stopped),
 687            ],
 688            cx,
 689        );
 690
 691        store.update(cx, |store, cx| store.start_server(server_1, cx));
 692
 693        cx.run_until_parked();
 694
 695        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));
 696
 697        cx.run_until_parked();
 698
 699        store
 700            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 701            .unwrap();
 702    }
 703
 704    #[gpui::test(iterations = 25)]
 705    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 706        const SERVER_1_ID: &'static str = "mcp-1";
 707
 708        let (_fs, project) = setup_context_server_test(
 709            cx,
 710            json!({"code.rs": ""}),
 711            vec![(SERVER_1_ID.into(), dummy_server_settings())],
 712        )
 713        .await;
 714
 715        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 716        let store = cx.new(|cx| {
 717            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 718        });
 719
 720        let server_id = ContextServerId(SERVER_1_ID.into());
 721
 722        let server_with_same_id_1 = Arc::new(ContextServer::new(
 723            server_id.clone(),
 724            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 725        ));
 726        let server_with_same_id_2 = Arc::new(ContextServer::new(
 727            server_id.clone(),
 728            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
 729        ));
 730
 731        // If we start another server with the same id, we should report that we stopped the previous one
 732        let _server_events = assert_server_events(
 733            &store,
 734            vec![
 735                (server_id.clone(), ContextServerStatus::Starting),
 736                (server_id.clone(), ContextServerStatus::Stopped),
 737                (server_id.clone(), ContextServerStatus::Starting),
 738                (server_id.clone(), ContextServerStatus::Running),
 739            ],
 740            cx,
 741        );
 742
 743        store.update(cx, |store, cx| {
 744            store.start_server(server_with_same_id_1.clone(), cx)
 745        });
 746        store.update(cx, |store, cx| {
 747            store.start_server(server_with_same_id_2.clone(), cx)
 748        });
 749
 750        cx.run_until_parked();
 751
 752        cx.update(|cx| {
 753            assert_eq!(
 754                store.read(cx).status_for_server(&server_id),
 755                Some(ContextServerStatus::Running)
 756            );
 757        });
 758    }
 759
 760    #[gpui::test]
 761    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 762        const SERVER_1_ID: &'static str = "mcp-1";
 763        const SERVER_2_ID: &'static str = "mcp-2";
 764
 765        let server_1_id = ContextServerId(SERVER_1_ID.into());
 766        let server_2_id = ContextServerId(SERVER_2_ID.into());
 767
 768        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 769
 770        let (_fs, project) = setup_context_server_test(
 771            cx,
 772            json!({"code.rs": ""}),
 773            vec![(
 774                SERVER_1_ID.into(),
 775                ContextServerSettings::Extension {
 776                    settings: json!({
 777                        "somevalue": true
 778                    }),
 779                },
 780            )],
 781        )
 782        .await;
 783
 784        let executor = cx.executor();
 785        let registry = cx.new(|_| {
 786            let mut registry = ContextServerDescriptorRegistry::new();
 787            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
 788            registry
 789        });
 790        let store = cx.new(|cx| {
 791            ContextServerStore::test_maintain_server_loop(
 792                Box::new(move |id, _| {
 793                    Arc::new(ContextServer::new(
 794                        id.clone(),
 795                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 796                    ))
 797                }),
 798                registry.clone(),
 799                project.read(cx).worktree_store(),
 800                cx,
 801            )
 802        });
 803
 804        // Ensure that mcp-1 starts up
 805        {
 806            let _server_events = assert_server_events(
 807                &store,
 808                vec![
 809                    (server_1_id.clone(), ContextServerStatus::Starting),
 810                    (server_1_id.clone(), ContextServerStatus::Running),
 811                ],
 812                cx,
 813            );
 814            cx.run_until_parked();
 815        }
 816
 817        // Ensure that mcp-1 is restarted when the configuration was changed
 818        {
 819            let _server_events = assert_server_events(
 820                &store,
 821                vec![
 822                    (server_1_id.clone(), ContextServerStatus::Stopped),
 823                    (server_1_id.clone(), ContextServerStatus::Starting),
 824                    (server_1_id.clone(), ContextServerStatus::Running),
 825                ],
 826                cx,
 827            );
 828            set_context_server_configuration(
 829                vec![(
 830                    server_1_id.0.clone(),
 831                    ContextServerSettings::Extension {
 832                        settings: json!({
 833                            "somevalue": false
 834                        }),
 835                    },
 836                )],
 837                cx,
 838            );
 839
 840            cx.run_until_parked();
 841        }
 842
 843        // Ensure that mcp-1 is not restarted when the configuration was not changed
 844        {
 845            let _server_events = assert_server_events(&store, vec![], cx);
 846            set_context_server_configuration(
 847                vec![(
 848                    server_1_id.0.clone(),
 849                    ContextServerSettings::Extension {
 850                        settings: json!({
 851                            "somevalue": false
 852                        }),
 853                    },
 854                )],
 855                cx,
 856            );
 857
 858            cx.run_until_parked();
 859        }
 860
 861        // Ensure that mcp-2 is started once it is added to the settings
 862        {
 863            let _server_events = assert_server_events(
 864                &store,
 865                vec![
 866                    (server_2_id.clone(), ContextServerStatus::Starting),
 867                    (server_2_id.clone(), ContextServerStatus::Running),
 868                ],
 869                cx,
 870            );
 871            set_context_server_configuration(
 872                vec![
 873                    (
 874                        server_1_id.0.clone(),
 875                        ContextServerSettings::Extension {
 876                            settings: json!({
 877                                "somevalue": false
 878                            }),
 879                        },
 880                    ),
 881                    (
 882                        server_2_id.0.clone(),
 883                        ContextServerSettings::Custom {
 884                            command: ContextServerCommand {
 885                                path: "somebinary".to_string(),
 886                                args: vec!["arg".to_string()],
 887                                env: None,
 888                            },
 889                        },
 890                    ),
 891                ],
 892                cx,
 893            );
 894
 895            cx.run_until_parked();
 896        }
 897
 898        // Ensure that mcp-2 is restarted once the args have changed
 899        {
 900            let _server_events = assert_server_events(
 901                &store,
 902                vec![
 903                    (server_2_id.clone(), ContextServerStatus::Stopped),
 904                    (server_2_id.clone(), ContextServerStatus::Starting),
 905                    (server_2_id.clone(), ContextServerStatus::Running),
 906                ],
 907                cx,
 908            );
 909            set_context_server_configuration(
 910                vec![
 911                    (
 912                        server_1_id.0.clone(),
 913                        ContextServerSettings::Extension {
 914                            settings: json!({
 915                                "somevalue": false
 916                            }),
 917                        },
 918                    ),
 919                    (
 920                        server_2_id.0.clone(),
 921                        ContextServerSettings::Custom {
 922                            command: ContextServerCommand {
 923                                path: "somebinary".to_string(),
 924                                args: vec!["anotherArg".to_string()],
 925                                env: None,
 926                            },
 927                        },
 928                    ),
 929                ],
 930                cx,
 931            );
 932
 933            cx.run_until_parked();
 934        }
 935
 936        // Ensure that mcp-2 is removed once it is removed from the settings
 937        {
 938            let _server_events = assert_server_events(
 939                &store,
 940                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 941                cx,
 942            );
 943            set_context_server_configuration(
 944                vec![(
 945                    server_1_id.0.clone(),
 946                    ContextServerSettings::Extension {
 947                        settings: json!({
 948                            "somevalue": false
 949                        }),
 950                    },
 951                )],
 952                cx,
 953            );
 954
 955            cx.run_until_parked();
 956
 957            cx.update(|cx| {
 958                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 959            });
 960        }
 961    }
 962
 963    fn set_context_server_configuration(
 964        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
 965        cx: &mut TestAppContext,
 966    ) {
 967        cx.update(|cx| {
 968            SettingsStore::update_global(cx, |store, cx| {
 969                let mut settings = ProjectSettings::default();
 970                for (id, config) in context_servers {
 971                    settings.context_servers.insert(id, config);
 972                }
 973                store
 974                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
 975                    .unwrap();
 976            })
 977        });
 978    }
 979
    /// Guard returned by `assert_server_events`.
    ///
    /// It holds the store subscription whose callback counts the
    /// `Event::ServerStatusChanged` events it receives. Its `Drop` impl compares
    /// that count against `expected_event_count`.
    struct ServerEvents {
        // Shared with the subscription callback, which increments it once per
        // matching event.
        received_event_count: Rc<RefCell<usize>>,
        // Number of events the test expects to observe (length of the
        // expected-events list).
        expected_event_count: usize,
        // Held for the guard's lifetime and never read; presumably dropping it
        // ends the subscription (gpui `Subscription` semantics, confirm).
        _subscription: Subscription,
    }
 985
 986    impl Drop for ServerEvents {
 987        fn drop(&mut self) {
 988            let actual_event_count = *self.received_event_count.borrow();
 989            assert_eq!(
 990                actual_event_count, self.expected_event_count,
 991                "
 992                Expected to receive {} context server store events, but received {} events",
 993                self.expected_event_count, actual_event_count
 994            );
 995        }
 996    }
 997
 998    fn dummy_server_settings() -> ContextServerSettings {
 999        ContextServerSettings::Custom {
1000            command: ContextServerCommand {
1001                path: "somebinary".to_string(),
1002                args: vec!["arg".to_string()],
1003                env: None,
1004            },
1005        }
1006    }
1007
1008    fn assert_server_events(
1009        store: &Entity<ContextServerStore>,
1010        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1011        cx: &mut TestAppContext,
1012    ) -> ServerEvents {
1013        cx.update(|cx| {
1014            let mut ix = 0;
1015            let received_event_count = Rc::new(RefCell::new(0));
1016            let expected_event_count = expected_events.len();
1017            let subscription = cx.subscribe(store, {
1018                let received_event_count = received_event_count.clone();
1019                move |_, event, _| match event {
1020                    Event::ServerStatusChanged {
1021                        server_id: actual_server_id,
1022                        status: actual_status,
1023                    } => {
1024                        let (expected_server_id, expected_status) = &expected_events[ix];
1025
1026                        assert_eq!(
1027                            actual_server_id, expected_server_id,
1028                            "Expected different server id at index {}",
1029                            ix
1030                        );
1031                        assert_eq!(
1032                            actual_status, expected_status,
1033                            "Expected different status at index {}",
1034                            ix
1035                        );
1036                        ix += 1;
1037                        *received_event_count.borrow_mut() += 1;
1038                    }
1039                }
1040            });
1041            ServerEvents {
1042                expected_event_count,
1043                received_event_count,
1044                _subscription: subscription,
1045            }
1046        })
1047    }
1048
1049    async fn setup_context_server_test(
1050        cx: &mut TestAppContext,
1051        files: serde_json::Value,
1052        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1053    ) -> (Arc<FakeFs>, Entity<Project>) {
1054        cx.update(|cx| {
1055            let settings_store = SettingsStore::test(cx);
1056            cx.set_global(settings_store);
1057            Project::init_settings(cx);
1058            let mut settings = ProjectSettings::get_global(cx).clone();
1059            for (id, config) in context_server_configurations {
1060                settings.context_servers.insert(id, config);
1061            }
1062            ProjectSettings::override_global(settings, cx);
1063        });
1064
1065        let fs = FakeFs::new(cx.executor());
1066        fs.insert_tree(path!("/test"), files).await;
1067        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1068
1069        (fs, project)
1070    }
1071
1072    struct FakeContextServerDescriptor {
1073        path: String,
1074    }
1075
1076    impl FakeContextServerDescriptor {
1077        fn new(path: impl Into<String>) -> Self {
1078            Self { path: path.into() }
1079        }
1080    }
1081
1082    impl ContextServerDescriptor for FakeContextServerDescriptor {
1083        fn command(
1084            &self,
1085            _worktree_store: Entity<WorktreeStore>,
1086            _cx: &AsyncApp,
1087        ) -> Task<Result<ContextServerCommand>> {
1088            Task::ready(Ok(ContextServerCommand {
1089                path: self.path.clone(),
1090                args: vec!["arg1".to_string(), "arg2".to_string()],
1091                env: None,
1092            }))
1093        }
1094
1095        fn configuration(
1096            &self,
1097            _worktree_store: Entity<WorktreeStore>,
1098            _cx: &AsyncApp,
1099        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1100            Task::ready(Ok(None))
1101        }
1102    }
1103}