context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   9use futures::{FutureExt as _, future::join_all};
  10use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  11use registry::ContextServerDescriptorRegistry;
  12use settings::{Settings as _, SettingsStore};
  13use util::ResultExt as _;
  14
  15use crate::{
  16    project_settings::{ContextServerSettings, ProjectSettings},
  17    worktree_store::WorktreeStore,
  18};
  19
/// Initializes context-server support by delegating to [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
  23
// Declares the `Restart` action in the `context_server` namespace.
actions!(context_server, [Restart]);
  25
/// Externally visible lifecycle status of a context server.
///
/// Derived from the store's internal `ContextServerState` and carried by
/// `Event::ServerStatusChanged`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server is in the process of starting.
    Starting,
    /// The server started successfully.
    Running,
    /// The server is stopped and no error was recorded.
    Stopped,
    /// The server is stopped and an error message was recorded for it.
    Error(Arc<str>),
}
  33
  34impl ContextServerStatus {
  35    fn from_state(state: &ContextServerState) -> Self {
  36        match state {
  37            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  38            ContextServerState::Running { .. } => ContextServerStatus::Running,
  39            ContextServerState::Stopped { error, .. } => {
  40                if let Some(error) = error {
  41                    ContextServerStatus::Error(error.clone())
  42                } else {
  43                    ContextServerStatus::Stopped
  44                }
  45            }
  46        }
  47    }
  48}
  49
/// Internal per-server bookkeeping kept by the store.
///
/// Every variant carries the server handle and the configuration the server
/// was created from.
enum ContextServerState {
    /// A start attempt is in flight.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Owns the spawned start task for as long as this state exists.
        // NOTE(review): presumably dropping it cancels the start — confirm
        // against gpui `Task` semantics.
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is not running; `error` holds the failure message when the
    /// stop was caused by a failed start.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Option<Arc<str>>,
    },
}
  66
  67impl ContextServerState {
  68    pub fn server(&self) -> Arc<ContextServer> {
  69        match self {
  70            ContextServerState::Starting { server, .. } => server.clone(),
  71            ContextServerState::Running { server, .. } => server.clone(),
  72            ContextServerState::Stopped { server, .. } => server.clone(),
  73        }
  74    }
  75
  76    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  77        match self {
  78            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  79            ContextServerState::Running { configuration, .. } => configuration.clone(),
  80            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  81        }
  82    }
  83}
  84
/// Resolved launch configuration for a context server.
///
/// Equality is used by the store to decide whether a server has to be
/// restarted after a settings or registry change.
#[derive(PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// Server whose command is taken directly from a `Custom` settings entry.
    Custom {
        command: ContextServerCommand,
    },
    /// Server whose command is obtained from a registry descriptor; `settings`
    /// is the value from the corresponding `Extension` settings entry.
    Extension {
        command: ContextServerCommand,
        settings: serde_json::Value,
    },
}
  95
  96impl ContextServerConfiguration {
  97    pub fn command(&self) -> &ContextServerCommand {
  98        match self {
  99            ContextServerConfiguration::Custom { command } => command,
 100            ContextServerConfiguration::Extension { command, .. } => command,
 101        }
 102    }
 103
 104    pub async fn from_settings(
 105        settings: ContextServerSettings,
 106        id: ContextServerId,
 107        registry: Entity<ContextServerDescriptorRegistry>,
 108        worktree_store: Entity<WorktreeStore>,
 109        cx: &AsyncApp,
 110    ) -> Option<Self> {
 111        match settings {
 112            ContextServerSettings::Custom { command } => {
 113                Some(ContextServerConfiguration::Custom { command })
 114            }
 115            ContextServerSettings::Extension { settings } => {
 116                let descriptor = cx
 117                    .update(|cx| registry.read(cx).context_server_descriptor(&id.0))
 118                    .ok()
 119                    .flatten()?;
 120
 121                let command = descriptor.command(worktree_store, cx).await.log_err()?;
 122
 123                Some(ContextServerConfiguration::Extension { command, settings })
 124            }
 125        }
 126    }
 127}
 128
/// Callback that builds a [`ContextServer`] from an id and its configuration.
///
/// When the store holds a factory it is used instead of constructing a stdio
/// server from the configuration's command.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 131
/// Tracks context servers and their lifecycle state, and optionally keeps
/// them in sync with settings and the descriptor registry.
pub struct ContextServerStore {
    // Current state of every known server, keyed by id.
    servers: HashMap<ContextServerId, ContextServerState>,
    // Used to pick the settings location (first visible worktree) and passed
    // along when resolving extension commands.
    worktree_store: Entity<WorktreeStore>,
    // Descriptor registry consulted for extension-provided servers.
    registry: Entity<ContextServerDescriptorRegistry>,
    // The in-flight reconciliation pass, if one is running.
    update_servers_task: Option<Task<Result<()>>>,
    // Optional override for how servers are constructed.
    context_server_factory: Option<ContextServerFactory>,
    // Set when a change arrives while a reconciliation pass is in flight, so
    // another pass is started once the current one finishes.
    needs_server_update: bool,
    // Observer subscriptions held for the lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
 141
/// Events emitted by [`ContextServerStore`].
pub enum Event {
    /// A server's state was recorded or the server was removed from the store
    /// (removal is reported with a `Stopped` status).
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}

// Lets observers subscribe to `Event`s emitted by the store.
impl EventEmitter<Event> for ContextServerStore {}
 150
 151impl ContextServerStore {
 152    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 153        Self::new_internal(
 154            true,
 155            None,
 156            ContextServerDescriptorRegistry::default_global(cx),
 157            worktree_store,
 158            cx,
 159        )
 160    }
 161
 162    #[cfg(any(test, feature = "test-support"))]
 163    pub fn test(
 164        registry: Entity<ContextServerDescriptorRegistry>,
 165        worktree_store: Entity<WorktreeStore>,
 166        cx: &mut Context<Self>,
 167    ) -> Self {
 168        Self::new_internal(false, None, registry, worktree_store, cx)
 169    }
 170
 171    #[cfg(any(test, feature = "test-support"))]
 172    pub fn test_maintain_server_loop(
 173        context_server_factory: ContextServerFactory,
 174        registry: Entity<ContextServerDescriptorRegistry>,
 175        worktree_store: Entity<WorktreeStore>,
 176        cx: &mut Context<Self>,
 177    ) -> Self {
 178        Self::new_internal(
 179            true,
 180            Some(context_server_factory),
 181            registry,
 182            worktree_store,
 183            cx,
 184        )
 185    }
 186
 187    fn new_internal(
 188        maintain_server_loop: bool,
 189        context_server_factory: Option<ContextServerFactory>,
 190        registry: Entity<ContextServerDescriptorRegistry>,
 191        worktree_store: Entity<WorktreeStore>,
 192        cx: &mut Context<Self>,
 193    ) -> Self {
 194        let subscriptions = if maintain_server_loop {
 195            vec![
 196                cx.observe(&registry, |this, _registry, cx| {
 197                    this.available_context_servers_changed(cx);
 198                }),
 199                cx.observe_global::<SettingsStore>(|this, cx| {
 200                    this.available_context_servers_changed(cx);
 201                }),
 202            ]
 203        } else {
 204            Vec::new()
 205        };
 206
 207        let mut this = Self {
 208            _subscriptions: subscriptions,
 209            worktree_store,
 210            registry,
 211            needs_server_update: false,
 212            servers: HashMap::default(),
 213            update_servers_task: None,
 214            context_server_factory,
 215        };
 216        if maintain_server_loop {
 217            this.available_context_servers_changed(cx);
 218        }
 219        this
 220    }
 221
 222    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 223        self.servers.get(id).map(|state| state.server())
 224    }
 225
 226    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 227        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 228            Some(server.clone())
 229        } else {
 230            None
 231        }
 232    }
 233
 234    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 235        self.servers.get(id).map(ContextServerStatus::from_state)
 236    }
 237
 238    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 239        self.servers.keys().cloned().collect()
 240    }
 241
 242    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 243        self.servers
 244            .values()
 245            .filter_map(|state| {
 246                if let ContextServerState::Running { server, .. } = state {
 247                    Some(server.clone())
 248                } else {
 249                    None
 250                }
 251            })
 252            .collect()
 253    }
 254
 255    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 256        cx.spawn(async move |this, cx| {
 257            let this = this.upgrade().context("Context server store dropped")?;
 258            let settings = this
 259                .update(cx, |this, cx| {
 260                    this.context_server_settings(cx)
 261                        .get(&server.id().0)
 262                        .cloned()
 263                })
 264                .ok()
 265                .flatten()
 266                .context("Failed to get context server settings")?;
 267
 268            let (registry, worktree_store) = this.update(cx, |this, _| {
 269                (this.registry.clone(), this.worktree_store.clone())
 270            })?;
 271            let configuration = ContextServerConfiguration::from_settings(
 272                settings,
 273                server.id(),
 274                registry,
 275                worktree_store,
 276                cx,
 277            )
 278            .await
 279            .context("Failed to create context server configuration")?;
 280
 281            this.update(cx, |this, cx| {
 282                this.run_server(server, Arc::new(configuration), cx)
 283            })
 284        })
 285        .detach_and_log_err(cx);
 286    }
 287
 288    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 289        let state = self
 290            .servers
 291            .remove(id)
 292            .context("Context server not found")?;
 293
 294        let server = state.server();
 295        let configuration = state.configuration();
 296        let mut result = Ok(());
 297        if let ContextServerState::Running { server, .. } = &state {
 298            result = server.stop();
 299        }
 300        drop(state);
 301
 302        self.update_server_state(
 303            id.clone(),
 304            ContextServerState::Stopped {
 305                configuration,
 306                server,
 307                error: None,
 308            },
 309            cx,
 310        );
 311
 312        result
 313    }
 314
 315    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 316        if let Some(state) = self.servers.get(&id) {
 317            let configuration = state.configuration();
 318
 319            self.stop_server(&state.server().id(), cx)?;
 320            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 321            self.run_server(new_server, configuration, cx);
 322        }
 323        Ok(())
 324    }
 325
 326    fn run_server(
 327        &mut self,
 328        server: Arc<ContextServer>,
 329        configuration: Arc<ContextServerConfiguration>,
 330        cx: &mut Context<Self>,
 331    ) {
 332        let id = server.id();
 333        if matches!(
 334            self.servers.get(&id),
 335            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 336        ) {
 337            self.stop_server(&id, cx).log_err();
 338        }
 339
 340        let task = cx.spawn({
 341            let id = server.id();
 342            let server = server.clone();
 343            let configuration = configuration.clone();
 344            async move |this, cx| {
 345                match server.clone().start(&cx).await {
 346                    Ok(_) => {
 347                        log::info!("Started {} context server", id);
 348                        debug_assert!(server.client().is_some());
 349
 350                        this.update(cx, |this, cx| {
 351                            this.update_server_state(
 352                                id.clone(),
 353                                ContextServerState::Running {
 354                                    server,
 355                                    configuration,
 356                                },
 357                                cx,
 358                            )
 359                        })
 360                        .log_err()
 361                    }
 362                    Err(err) => {
 363                        log::error!("{} context server failed to start: {}", id, err);
 364                        this.update(cx, |this, cx| {
 365                            this.update_server_state(
 366                                id.clone(),
 367                                ContextServerState::Stopped {
 368                                    configuration,
 369                                    server,
 370                                    error: Some(err.to_string().into()),
 371                                },
 372                                cx,
 373                            )
 374                        })
 375                        .log_err()
 376                    }
 377                };
 378            }
 379        });
 380
 381        self.update_server_state(
 382            id.clone(),
 383            ContextServerState::Starting {
 384                configuration,
 385                _task: task,
 386                server,
 387            },
 388            cx,
 389        );
 390    }
 391
 392    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 393        let state = self
 394            .servers
 395            .remove(id)
 396            .context("Context server not found")?;
 397        drop(state);
 398        cx.emit(Event::ServerStatusChanged {
 399            server_id: id.clone(),
 400            status: ContextServerStatus::Stopped,
 401        });
 402        Ok(())
 403    }
 404
 405    fn create_context_server(
 406        &self,
 407        id: ContextServerId,
 408        configuration: Arc<ContextServerConfiguration>,
 409    ) -> Result<Arc<ContextServer>> {
 410        if let Some(factory) = self.context_server_factory.as_ref() {
 411            Ok(factory(id, configuration))
 412        } else {
 413            Ok(Arc::new(ContextServer::stdio(
 414                id,
 415                configuration.command().clone(),
 416            )))
 417        }
 418    }
 419
 420    fn context_server_settings<'a>(
 421        &'a self,
 422        cx: &'a App,
 423    ) -> &'a HashMap<Arc<str>, ContextServerSettings> {
 424        let location = self
 425            .worktree_store
 426            .read(cx)
 427            .visible_worktrees(cx)
 428            .next()
 429            .map(|worktree| settings::SettingsLocation {
 430                worktree_id: worktree.read(cx).id(),
 431                path: Path::new(""),
 432            });
 433        &ProjectSettings::get(location, cx).context_servers
 434    }
 435
 436    fn update_server_state(
 437        &mut self,
 438        id: ContextServerId,
 439        state: ContextServerState,
 440        cx: &mut Context<Self>,
 441    ) {
 442        let status = ContextServerStatus::from_state(&state);
 443        self.servers.insert(id.clone(), state);
 444        cx.emit(Event::ServerStatusChanged {
 445            server_id: id,
 446            status,
 447        });
 448    }
 449
 450    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 451        if self.update_servers_task.is_some() {
 452            self.needs_server_update = true;
 453        } else {
 454            self.needs_server_update = false;
 455            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 456                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 457                    log::error!("Error maintaining context servers: {}", err);
 458                }
 459
 460                this.update(cx, |this, cx| {
 461                    this.update_servers_task.take();
 462                    if this.needs_server_update {
 463                        this.available_context_servers_changed(cx);
 464                    }
 465                })?;
 466
 467                Ok(())
 468            }));
 469        }
 470    }
 471
 472    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 473        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, cx| {
 474            (
 475                this.context_server_settings(cx).clone(),
 476                this.registry.clone(),
 477                this.worktree_store.clone(),
 478            )
 479        })?;
 480
 481        for (id, _) in
 482            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 483        {
 484            configured_servers
 485                .entry(id)
 486                .or_insert(ContextServerSettings::Extension {
 487                    settings: serde_json::json!({}),
 488                });
 489        }
 490
 491        let configured_servers = join_all(configured_servers.into_iter().map(|(id, settings)| {
 492            let id = ContextServerId(id);
 493            ContextServerConfiguration::from_settings(
 494                settings,
 495                id.clone(),
 496                registry.clone(),
 497                worktree_store.clone(),
 498                cx,
 499            )
 500            .map(|config| (id, config))
 501        }))
 502        .await
 503        .into_iter()
 504        .filter_map(|(id, config)| config.map(|config| (id, config)))
 505        .collect::<HashMap<_, _>>();
 506
 507        let mut servers_to_start = Vec::new();
 508        let mut servers_to_remove = HashSet::default();
 509        let mut servers_to_stop = HashSet::default();
 510
 511        this.update(cx, |this, _cx| {
 512            for server_id in this.servers.keys() {
 513                // All servers that are not in desired_servers should be removed from the store.
 514                // This can happen if the user removed a server from the context server settings.
 515                if !configured_servers.contains_key(&server_id) {
 516                    servers_to_remove.insert(server_id.clone());
 517                }
 518            }
 519
 520            for (id, config) in configured_servers {
 521                let existing_config = this.servers.get(&id).map(|state| state.configuration());
 522                if existing_config.as_deref() != Some(&config) {
 523                    let config = Arc::new(config);
 524                    if let Some(server) = this
 525                        .create_context_server(id.clone(), config.clone())
 526                        .log_err()
 527                    {
 528                        servers_to_start.push((server, config));
 529                        if this.servers.contains_key(&id) {
 530                            servers_to_stop.insert(id);
 531                        }
 532                    }
 533                }
 534            }
 535        })?;
 536
 537        this.update(cx, |this, cx| {
 538            for id in servers_to_stop {
 539                this.stop_server(&id, cx)?;
 540            }
 541            for id in servers_to_remove {
 542                this.remove_server(&id, cx)?;
 543            }
 544            for (server, config) in servers_to_start {
 545                this.run_server(server, config, cx);
 546            }
 547            anyhow::Ok(())
 548        })?
 549    }
 550}
 551
 552#[cfg(test)]
 553mod tests {
 554    use super::*;
 555    use crate::{
 556        FakeFs, Project, context_server_store::registry::ContextServerDescriptor,
 557        project_settings::ProjectSettings,
 558    };
 559    use context_server::test::create_fake_transport;
 560    use gpui::{AppContext, TestAppContext, UpdateGlobal as _};
 561    use serde_json::json;
 562    use std::{cell::RefCell, rc::Rc};
 563    use util::path;
 564
    /// Starting and stopping servers by hand is reflected by
    /// `status_for_server`, and stopping one server leaves the other running.
    #[gpui::test]
    async fn test_context_server_status(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";
        const SERVER_2_ID: &'static str = "mcp-2";

        // Both servers need a settings entry, because `start_server` resolves
        // the configuration from settings.
        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        // `ContextServerStore::test` creates a store without the maintain loop,
        // so only the explicit calls below change server state.
        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        // Only the first server has been started so far.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
        });

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        // Both servers are running now.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Running)
            );
        });

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();

        // Stopping the second server must not affect the first one.
        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_1_id),
                Some(ContextServerStatus::Running)
            );
            assert_eq!(
                store.read(cx).status_for_server(&server_2_id),
                Some(ContextServerStatus::Stopped)
            );
        });
    }
 639
    /// Starting two servers and stopping the second one passes the expected
    /// sequence of status events to `assert_server_events`.
    #[gpui::test]
    async fn test_context_server_status_events(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";
        const SERVER_2_ID: &'static str = "mcp-2";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![
                (SERVER_1_ID.into(), dummy_server_settings()),
                (SERVER_2_ID.into(), dummy_server_settings()),
            ],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_1_id = ContextServerId(SERVER_1_ID.into());
        let server_2_id = ContextServerId(SERVER_2_ID.into());

        let server_1 = Arc::new(ContextServer::new(
            server_1_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_2 = Arc::new(ContextServer::new(
            server_2_id.clone(),
            Arc::new(create_fake_transport(SERVER_2_ID, cx.executor())),
        ));

        // The expectation is registered before any server is started, and the
        // returned value is kept alive until the end of the test.
        // NOTE(review): presumably the helper verifies the sequence when that
        // value is dropped — confirm against `assert_server_events`.
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_1_id.clone(), ContextServerStatus::Starting),
                (server_1_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Starting),
                (server_2_id.clone(), ContextServerStatus::Running),
                (server_2_id.clone(), ContextServerStatus::Stopped),
            ],
            cx,
        );

        store.update(cx, |store, cx| store.start_server(server_1, cx));

        cx.run_until_parked();

        store.update(cx, |store, cx| store.start_server(server_2.clone(), cx));

        cx.run_until_parked();

        store
            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
            .unwrap();
    }
 696
    /// Starting a second server under an id whose first server is still
    /// starting stops the first one and ends with the id reported as running.
    #[gpui::test(iterations = 25)]
    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
        const SERVER_1_ID: &'static str = "mcp-1";

        let (_fs, project) = setup_context_server_test(
            cx,
            json!({"code.rs": ""}),
            vec![(SERVER_1_ID.into(), dummy_server_settings())],
        )
        .await;

        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
        let store = cx.new(|cx| {
            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
        });

        let server_id = ContextServerId(SERVER_1_ID.into());

        // Two distinct server instances that share the same id.
        let server_with_same_id_1 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));
        let server_with_same_id_2 = Arc::new(ContextServer::new(
            server_id.clone(),
            Arc::new(create_fake_transport(SERVER_1_ID, cx.executor())),
        ));

        // If we start another server with the same id, we should report that we stopped the previous one
        let _server_events = assert_server_events(
            &store,
            vec![
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Stopped),
                (server_id.clone(), ContextServerStatus::Starting),
                (server_id.clone(), ContextServerStatus::Running),
            ],
            cx,
        );

        // Both starts are issued back to back, without letting the executor
        // run in between.
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_1.clone(), cx)
        });
        store.update(cx, |store, cx| {
            store.start_server(server_with_same_id_2.clone(), cx)
        });

        cx.run_until_parked();

        cx.update(|cx| {
            assert_eq!(
                store.read(cx).status_for_server(&server_id),
                Some(ContextServerStatus::Running)
            );
        });
    }
 752
 753    #[gpui::test]
 754    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 755        const SERVER_1_ID: &'static str = "mcp-1";
 756        const SERVER_2_ID: &'static str = "mcp-2";
 757
 758        let server_1_id = ContextServerId(SERVER_1_ID.into());
 759        let server_2_id = ContextServerId(SERVER_2_ID.into());
 760
 761        let fake_descriptor_1 = Arc::new(FakeContextServerDescriptor::new(SERVER_1_ID));
 762
 763        let (_fs, project) = setup_context_server_test(
 764            cx,
 765            json!({"code.rs": ""}),
 766            vec![(
 767                SERVER_1_ID.into(),
 768                ContextServerSettings::Extension {
 769                    settings: json!({
 770                        "somevalue": true
 771                    }),
 772                },
 773            )],
 774        )
 775        .await;
 776
 777        let executor = cx.executor();
 778        let registry = cx.new(|_| {
 779            let mut registry = ContextServerDescriptorRegistry::new();
 780            registry.register_context_server_descriptor(SERVER_1_ID.into(), fake_descriptor_1);
 781            registry
 782        });
 783        let store = cx.new(|cx| {
 784            ContextServerStore::test_maintain_server_loop(
 785                Box::new(move |id, _| {
 786                    Arc::new(ContextServer::new(
 787                        id.clone(),
 788                        Arc::new(create_fake_transport(id.0.to_string(), executor.clone())),
 789                    ))
 790                }),
 791                registry.clone(),
 792                project.read(cx).worktree_store(),
 793                cx,
 794            )
 795        });
 796
 797        // Ensure that mcp-1 starts up
 798        {
 799            let _server_events = assert_server_events(
 800                &store,
 801                vec![
 802                    (server_1_id.clone(), ContextServerStatus::Starting),
 803                    (server_1_id.clone(), ContextServerStatus::Running),
 804                ],
 805                cx,
 806            );
 807            cx.run_until_parked();
 808        }
 809
 810        // Ensure that mcp-1 is restarted when the configuration was changed
 811        {
 812            let _server_events = assert_server_events(
 813                &store,
 814                vec![
 815                    (server_1_id.clone(), ContextServerStatus::Stopped),
 816                    (server_1_id.clone(), ContextServerStatus::Starting),
 817                    (server_1_id.clone(), ContextServerStatus::Running),
 818                ],
 819                cx,
 820            );
 821            set_context_server_configuration(
 822                vec![(
 823                    server_1_id.0.clone(),
 824                    ContextServerSettings::Extension {
 825                        settings: json!({
 826                            "somevalue": false
 827                        }),
 828                    },
 829                )],
 830                cx,
 831            );
 832
 833            cx.run_until_parked();
 834        }
 835
 836        // Ensure that mcp-1 is not restarted when the configuration was not changed
 837        {
 838            let _server_events = assert_server_events(&store, vec![], cx);
 839            set_context_server_configuration(
 840                vec![(
 841                    server_1_id.0.clone(),
 842                    ContextServerSettings::Extension {
 843                        settings: json!({
 844                            "somevalue": false
 845                        }),
 846                    },
 847                )],
 848                cx,
 849            );
 850
 851            cx.run_until_parked();
 852        }
 853
 854        // Ensure that mcp-2 is started once it is added to the settings
 855        {
 856            let _server_events = assert_server_events(
 857                &store,
 858                vec![
 859                    (server_2_id.clone(), ContextServerStatus::Starting),
 860                    (server_2_id.clone(), ContextServerStatus::Running),
 861                ],
 862                cx,
 863            );
 864            set_context_server_configuration(
 865                vec![
 866                    (
 867                        server_1_id.0.clone(),
 868                        ContextServerSettings::Extension {
 869                            settings: json!({
 870                                "somevalue": false
 871                            }),
 872                        },
 873                    ),
 874                    (
 875                        server_2_id.0.clone(),
 876                        ContextServerSettings::Custom {
 877                            command: ContextServerCommand {
 878                                path: "somebinary".to_string(),
 879                                args: vec!["arg".to_string()],
 880                                env: None,
 881                            },
 882                        },
 883                    ),
 884                ],
 885                cx,
 886            );
 887
 888            cx.run_until_parked();
 889        }
 890
 891        // Ensure that mcp-2 is restarted once the args have changed
 892        {
 893            let _server_events = assert_server_events(
 894                &store,
 895                vec![
 896                    (server_2_id.clone(), ContextServerStatus::Stopped),
 897                    (server_2_id.clone(), ContextServerStatus::Starting),
 898                    (server_2_id.clone(), ContextServerStatus::Running),
 899                ],
 900                cx,
 901            );
 902            set_context_server_configuration(
 903                vec![
 904                    (
 905                        server_1_id.0.clone(),
 906                        ContextServerSettings::Extension {
 907                            settings: json!({
 908                                "somevalue": false
 909                            }),
 910                        },
 911                    ),
 912                    (
 913                        server_2_id.0.clone(),
 914                        ContextServerSettings::Custom {
 915                            command: ContextServerCommand {
 916                                path: "somebinary".to_string(),
 917                                args: vec!["anotherArg".to_string()],
 918                                env: None,
 919                            },
 920                        },
 921                    ),
 922                ],
 923                cx,
 924            );
 925
 926            cx.run_until_parked();
 927        }
 928
 929        // Ensure that mcp-2 is removed once it is removed from the settings
 930        {
 931            let _server_events = assert_server_events(
 932                &store,
 933                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 934                cx,
 935            );
 936            set_context_server_configuration(
 937                vec![(
 938                    server_1_id.0.clone(),
 939                    ContextServerSettings::Extension {
 940                        settings: json!({
 941                            "somevalue": false
 942                        }),
 943                    },
 944                )],
 945                cx,
 946            );
 947
 948            cx.run_until_parked();
 949
 950            cx.update(|cx| {
 951                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 952            });
 953        }
 954    }
 955
 956    fn set_context_server_configuration(
 957        context_servers: Vec<(Arc<str>, ContextServerSettings)>,
 958        cx: &mut TestAppContext,
 959    ) {
 960        cx.update(|cx| {
 961            SettingsStore::update_global(cx, |store, cx| {
 962                let mut settings = ProjectSettings::default();
 963                for (id, config) in context_servers {
 964                    settings.context_servers.insert(id, config);
 965                }
 966                store
 967                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
 968                    .unwrap();
 969            })
 970        });
 971    }
 972
    /// Guard value returned by `assert_server_events`.
    ///
    /// Its `Drop` implementation asserts that the number of received events equals
    /// `expected_event_count`.
    struct ServerEvents {
        /// Number of events counted so far; shared with the subscription callback, which
        /// increments it.
        received_event_count: Rc<RefCell<usize>>,
        /// Number of events that must have been counted by the time the guard is dropped.
        expected_event_count: usize,
        /// Held only so that the subscription lives as long as this guard
        /// (presumably dropping it unsubscribes; confirm against gpui).
        _subscription: Subscription,
    }
 978
 979    impl Drop for ServerEvents {
 980        fn drop(&mut self) {
 981            let actual_event_count = *self.received_event_count.borrow();
 982            assert_eq!(
 983                actual_event_count, self.expected_event_count,
 984                "
 985                Expected to receive {} context server store events, but received {} events",
 986                self.expected_event_count, actual_event_count
 987            );
 988        }
 989    }
 990
 991    fn dummy_server_settings() -> ContextServerSettings {
 992        ContextServerSettings::Custom {
 993            command: ContextServerCommand {
 994                path: "somebinary".to_string(),
 995                args: vec!["arg".to_string()],
 996                env: None,
 997            },
 998        }
 999    }
1000
1001    fn assert_server_events(
1002        store: &Entity<ContextServerStore>,
1003        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
1004        cx: &mut TestAppContext,
1005    ) -> ServerEvents {
1006        cx.update(|cx| {
1007            let mut ix = 0;
1008            let received_event_count = Rc::new(RefCell::new(0));
1009            let expected_event_count = expected_events.len();
1010            let subscription = cx.subscribe(store, {
1011                let received_event_count = received_event_count.clone();
1012                move |_, event, _| match event {
1013                    Event::ServerStatusChanged {
1014                        server_id: actual_server_id,
1015                        status: actual_status,
1016                    } => {
1017                        let (expected_server_id, expected_status) = &expected_events[ix];
1018
1019                        assert_eq!(
1020                            actual_server_id, expected_server_id,
1021                            "Expected different server id at index {}",
1022                            ix
1023                        );
1024                        assert_eq!(
1025                            actual_status, expected_status,
1026                            "Expected different status at index {}",
1027                            ix
1028                        );
1029                        ix += 1;
1030                        *received_event_count.borrow_mut() += 1;
1031                    }
1032                }
1033            });
1034            ServerEvents {
1035                expected_event_count,
1036                received_event_count,
1037                _subscription: subscription,
1038            }
1039        })
1040    }
1041
1042    async fn setup_context_server_test(
1043        cx: &mut TestAppContext,
1044        files: serde_json::Value,
1045        context_server_configurations: Vec<(Arc<str>, ContextServerSettings)>,
1046    ) -> (Arc<FakeFs>, Entity<Project>) {
1047        cx.update(|cx| {
1048            let settings_store = SettingsStore::test(cx);
1049            cx.set_global(settings_store);
1050            Project::init_settings(cx);
1051            let mut settings = ProjectSettings::get_global(cx).clone();
1052            for (id, config) in context_server_configurations {
1053                settings.context_servers.insert(id, config);
1054            }
1055            ProjectSettings::override_global(settings, cx);
1056        });
1057
1058        let fs = FakeFs::new(cx.executor());
1059        fs.insert_tree(path!("/test"), files).await;
1060        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1061
1062        (fs, project)
1063    }
1064
1065    struct FakeContextServerDescriptor {
1066        path: String,
1067    }
1068
1069    impl FakeContextServerDescriptor {
1070        fn new(path: impl Into<String>) -> Self {
1071            Self { path: path.into() }
1072        }
1073    }
1074
1075    impl ContextServerDescriptor for FakeContextServerDescriptor {
1076        fn command(
1077            &self,
1078            _worktree_store: Entity<WorktreeStore>,
1079            _cx: &AsyncApp,
1080        ) -> Task<Result<ContextServerCommand>> {
1081            Task::ready(Ok(ContextServerCommand {
1082                path: self.path.clone(),
1083                args: vec!["arg1".to_string(), "arg2".to_string()],
1084                env: None,
1085            }))
1086        }
1087
1088        fn configuration(
1089            &self,
1090            _worktree_store: Entity<WorktreeStore>,
1091            _cx: &AsyncApp,
1092        ) -> Task<Result<Option<::extension::ContextServerConfiguration>>> {
1093            Task::ready(Ok(None))
1094        }
1095    }
1096}