context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerId};
   9use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  10use registry::ContextServerDescriptorRegistry;
  11use settings::{Settings as _, SettingsStore};
  12use util::ResultExt as _;
  13
  14use crate::{
  15    project_settings::{ContextServerConfiguration, ProjectSettings},
  16    worktree_store::WorktreeStore,
  17};
  18
  19pub fn init(cx: &mut App) {
  20    extension::init(cx);
  21}
  22
  23actions!(context_server, [Restart]);
  24
  25#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  26pub enum ContextServerStatus {
  27    Starting,
  28    Running,
  29    Stopped,
  30    Error(Arc<str>),
  31}
  32
  33impl ContextServerStatus {
  34    fn from_state(state: &ContextServerState) -> Self {
  35        match state {
  36            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  37            ContextServerState::Running { .. } => ContextServerStatus::Running,
  38            ContextServerState::Stopped { error, .. } => {
  39                if let Some(error) = error {
  40                    ContextServerStatus::Error(error.clone())
  41                } else {
  42                    ContextServerStatus::Stopped
  43                }
  44            }
  45        }
  46    }
  47}
  48
  49enum ContextServerState {
  50    Starting {
  51        server: Arc<ContextServer>,
  52        configuration: Arc<ContextServerConfiguration>,
  53        _task: Task<()>,
  54    },
  55    Running {
  56        server: Arc<ContextServer>,
  57        configuration: Arc<ContextServerConfiguration>,
  58    },
  59    Stopped {
  60        server: Arc<ContextServer>,
  61        configuration: Arc<ContextServerConfiguration>,
  62        error: Option<Arc<str>>,
  63    },
  64}
  65
  66impl ContextServerState {
  67    pub fn server(&self) -> Arc<ContextServer> {
  68        match self {
  69            ContextServerState::Starting { server, .. } => server.clone(),
  70            ContextServerState::Running { server, .. } => server.clone(),
  71            ContextServerState::Stopped { server, .. } => server.clone(),
  72        }
  73    }
  74
  75    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  76        match self {
  77            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  78            ContextServerState::Running { configuration, .. } => configuration.clone(),
  79            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  80        }
  81    }
  82}
  83
  84pub type ContextServerFactory =
  85    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
  86
  87pub struct ContextServerStore {
  88    servers: HashMap<ContextServerId, ContextServerState>,
  89    worktree_store: Entity<WorktreeStore>,
  90    registry: Entity<ContextServerDescriptorRegistry>,
  91    update_servers_task: Option<Task<Result<()>>>,
  92    context_server_factory: Option<ContextServerFactory>,
  93    needs_server_update: bool,
  94    _subscriptions: Vec<Subscription>,
  95}
  96
  97pub enum Event {
  98    ServerStatusChanged {
  99        server_id: ContextServerId,
 100        status: ContextServerStatus,
 101    },
 102}
 103
 104impl EventEmitter<Event> for ContextServerStore {}
 105
 106impl ContextServerStore {
 107    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 108        Self::new_internal(
 109            true,
 110            None,
 111            ContextServerDescriptorRegistry::default_global(cx),
 112            worktree_store,
 113            cx,
 114        )
 115    }
 116
 117    #[cfg(any(test, feature = "test-support"))]
 118    pub fn test(
 119        registry: Entity<ContextServerDescriptorRegistry>,
 120        worktree_store: Entity<WorktreeStore>,
 121        cx: &mut Context<Self>,
 122    ) -> Self {
 123        Self::new_internal(false, None, registry, worktree_store, cx)
 124    }
 125
 126    #[cfg(any(test, feature = "test-support"))]
 127    pub fn test_maintain_server_loop(
 128        context_server_factory: ContextServerFactory,
 129        registry: Entity<ContextServerDescriptorRegistry>,
 130        worktree_store: Entity<WorktreeStore>,
 131        cx: &mut Context<Self>,
 132    ) -> Self {
 133        Self::new_internal(
 134            true,
 135            Some(context_server_factory),
 136            registry,
 137            worktree_store,
 138            cx,
 139        )
 140    }
 141
 142    fn new_internal(
 143        maintain_server_loop: bool,
 144        context_server_factory: Option<ContextServerFactory>,
 145        registry: Entity<ContextServerDescriptorRegistry>,
 146        worktree_store: Entity<WorktreeStore>,
 147        cx: &mut Context<Self>,
 148    ) -> Self {
 149        let subscriptions = if maintain_server_loop {
 150            vec![
 151                cx.observe(&registry, |this, _registry, cx| {
 152                    this.available_context_servers_changed(cx);
 153                }),
 154                cx.observe_global::<SettingsStore>(|this, cx| {
 155                    this.available_context_servers_changed(cx);
 156                }),
 157            ]
 158        } else {
 159            Vec::new()
 160        };
 161
 162        let mut this = Self {
 163            _subscriptions: subscriptions,
 164            worktree_store,
 165            registry,
 166            needs_server_update: false,
 167            servers: HashMap::default(),
 168            update_servers_task: None,
 169            context_server_factory,
 170        };
 171        if maintain_server_loop {
 172            this.available_context_servers_changed(cx);
 173        }
 174        this
 175    }
 176
 177    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 178        self.servers.get(id).map(|state| state.server())
 179    }
 180
 181    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 182        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 183            Some(server.clone())
 184        } else {
 185            None
 186        }
 187    }
 188
 189    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 190        self.servers.get(id).map(ContextServerStatus::from_state)
 191    }
 192
 193    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 194        self.servers.keys().cloned().collect()
 195    }
 196
 197    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 198        self.servers
 199            .values()
 200            .filter_map(|state| {
 201                if let ContextServerState::Running { server, .. } = state {
 202                    Some(server.clone())
 203                } else {
 204                    None
 205                }
 206            })
 207            .collect()
 208    }
 209
 210    pub fn start_server(
 211        &mut self,
 212        server: Arc<ContextServer>,
 213        cx: &mut Context<Self>,
 214    ) -> Result<()> {
 215        let location = self
 216            .worktree_store
 217            .read(cx)
 218            .visible_worktrees(cx)
 219            .next()
 220            .map(|worktree| settings::SettingsLocation {
 221                worktree_id: worktree.read(cx).id(),
 222                path: Path::new(""),
 223            });
 224        let settings = ProjectSettings::get(location, cx);
 225        let configuration = settings
 226            .context_servers
 227            .get(&server.id().0)
 228            .context("Failed to load context server configuration from settings")?
 229            .clone();
 230
 231        self.run_server(server, Arc::new(configuration), cx);
 232        Ok(())
 233    }
 234
 235    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 236        let state = self
 237            .servers
 238            .remove(id)
 239            .context("Context server not found")?;
 240
 241        let server = state.server();
 242        let configuration = state.configuration();
 243        let mut result = Ok(());
 244        if let ContextServerState::Running { server, .. } = &state {
 245            result = server.stop();
 246        }
 247        drop(state);
 248
 249        self.update_server_state(
 250            id.clone(),
 251            ContextServerState::Stopped {
 252                configuration,
 253                server,
 254                error: None,
 255            },
 256            cx,
 257        );
 258
 259        result
 260    }
 261
 262    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 263        if let Some(state) = self.servers.get(&id) {
 264            let configuration = state.configuration();
 265
 266            self.stop_server(&state.server().id(), cx)?;
 267            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 268            self.run_server(new_server, configuration, cx);
 269        }
 270        Ok(())
 271    }
 272
 273    fn run_server(
 274        &mut self,
 275        server: Arc<ContextServer>,
 276        configuration: Arc<ContextServerConfiguration>,
 277        cx: &mut Context<Self>,
 278    ) {
 279        let id = server.id();
 280        if matches!(
 281            self.servers.get(&id),
 282            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 283        ) {
 284            self.stop_server(&id, cx).log_err();
 285        }
 286
 287        let task = cx.spawn({
 288            let id = server.id();
 289            let server = server.clone();
 290            let configuration = configuration.clone();
 291            async move |this, cx| {
 292                match server.clone().start(&cx).await {
 293                    Ok(_) => {
 294                        log::info!("Started {} context server", id);
 295                        debug_assert!(server.client().is_some());
 296
 297                        this.update(cx, |this, cx| {
 298                            this.update_server_state(
 299                                id.clone(),
 300                                ContextServerState::Running {
 301                                    server,
 302                                    configuration,
 303                                },
 304                                cx,
 305                            )
 306                        })
 307                        .log_err()
 308                    }
 309                    Err(err) => {
 310                        log::error!("{} context server failed to start: {}", id, err);
 311                        this.update(cx, |this, cx| {
 312                            this.update_server_state(
 313                                id.clone(),
 314                                ContextServerState::Stopped {
 315                                    configuration,
 316                                    server,
 317                                    error: Some(err.to_string().into()),
 318                                },
 319                                cx,
 320                            )
 321                        })
 322                        .log_err()
 323                    }
 324                };
 325            }
 326        });
 327
 328        self.update_server_state(
 329            id.clone(),
 330            ContextServerState::Starting {
 331                configuration,
 332                _task: task,
 333                server,
 334            },
 335            cx,
 336        );
 337    }
 338
 339    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 340        let state = self
 341            .servers
 342            .remove(id)
 343            .context("Context server not found")?;
 344        drop(state);
 345        cx.emit(Event::ServerStatusChanged {
 346            server_id: id.clone(),
 347            status: ContextServerStatus::Stopped,
 348        });
 349        Ok(())
 350    }
 351
 352    fn is_configuration_valid(&self, configuration: &ContextServerConfiguration) -> bool {
 353        // Command must be some when we are running in stdio mode.
 354        self.context_server_factory.as_ref().is_some() || configuration.command.is_some()
 355    }
 356
 357    fn create_context_server(
 358        &self,
 359        id: ContextServerId,
 360        configuration: Arc<ContextServerConfiguration>,
 361    ) -> Result<Arc<ContextServer>> {
 362        if let Some(factory) = self.context_server_factory.as_ref() {
 363            Ok(factory(id, configuration))
 364        } else {
 365            let command = configuration
 366                .command
 367                .clone()
 368                .context("Missing command to run context server")?;
 369            Ok(Arc::new(ContextServer::stdio(id, command)))
 370        }
 371    }
 372
 373    fn update_server_state(
 374        &mut self,
 375        id: ContextServerId,
 376        state: ContextServerState,
 377        cx: &mut Context<Self>,
 378    ) {
 379        let status = ContextServerStatus::from_state(&state);
 380        self.servers.insert(id.clone(), state);
 381        cx.emit(Event::ServerStatusChanged {
 382            server_id: id,
 383            status,
 384        });
 385    }
 386
 387    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 388        if self.update_servers_task.is_some() {
 389            self.needs_server_update = true;
 390        } else {
 391            self.needs_server_update = false;
 392            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 393                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 394                    log::error!("Error maintaining context servers: {}", err);
 395                }
 396
 397                this.update(cx, |this, cx| {
 398                    this.update_servers_task.take();
 399                    if this.needs_server_update {
 400                        this.available_context_servers_changed(cx);
 401                    }
 402                })?;
 403
 404                Ok(())
 405            }));
 406        }
 407    }
 408
 409    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 410        let mut desired_servers = HashMap::default();
 411
 412        let (registry, worktree_store) = this.update(cx, |this, cx| {
 413            let location = this
 414                .worktree_store
 415                .read(cx)
 416                .visible_worktrees(cx)
 417                .next()
 418                .map(|worktree| settings::SettingsLocation {
 419                    worktree_id: worktree.read(cx).id(),
 420                    path: Path::new(""),
 421                });
 422            let settings = ProjectSettings::get(location, cx);
 423            desired_servers = settings.context_servers.clone();
 424
 425            (this.registry.clone(), this.worktree_store.clone())
 426        })?;
 427
 428        for (id, descriptor) in
 429            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 430        {
 431            let config = desired_servers.entry(id.clone()).or_default();
 432            if config.command.is_none() {
 433                if let Some(extension_command) = descriptor
 434                    .command(worktree_store.clone(), &cx)
 435                    .await
 436                    .log_err()
 437                {
 438                    config.command = Some(extension_command);
 439                }
 440            }
 441        }
 442
 443        this.update(cx, |this, _| {
 444            // Filter out configurations without commands, the user uninstalled an extension.
 445            desired_servers.retain(|_, configuration| this.is_configuration_valid(configuration));
 446        })?;
 447
 448        let mut servers_to_start = Vec::new();
 449        let mut servers_to_remove = HashSet::default();
 450        let mut servers_to_stop = HashSet::default();
 451
 452        this.update(cx, |this, _cx| {
 453            for server_id in this.servers.keys() {
 454                // All servers that are not in desired_servers should be removed from the store.
 455                // E.g. this can happen if the user removed a server from the configuration,
 456                // or the user uninstalled an extension.
 457                if !desired_servers.contains_key(&server_id.0) {
 458                    servers_to_remove.insert(server_id.clone());
 459                }
 460            }
 461
 462            for (id, config) in desired_servers {
 463                let id = ContextServerId(id.clone());
 464
 465                let existing_config = this.servers.get(&id).map(|state| state.configuration());
 466                if existing_config.as_deref() != Some(&config) {
 467                    let config = Arc::new(config);
 468                    if let Some(server) = this
 469                        .create_context_server(id.clone(), config.clone())
 470                        .log_err()
 471                    {
 472                        servers_to_start.push((server, config));
 473                        if this.servers.contains_key(&id) {
 474                            servers_to_stop.insert(id);
 475                        }
 476                    }
 477                }
 478            }
 479        })?;
 480
 481        for id in servers_to_stop {
 482            this.update(cx, |this, cx| this.stop_server(&id, cx).ok())?;
 483        }
 484
 485        for id in servers_to_remove {
 486            this.update(cx, |this, cx| this.remove_server(&id, cx).ok())?;
 487        }
 488
 489        for (server, config) in servers_to_start {
 490            this.update(cx, |this, cx| this.run_server(server, config, cx))
 491                .log_err();
 492        }
 493
 494        Ok(())
 495    }
 496}
 497
 498#[cfg(test)]
 499mod tests {
 500    use super::*;
 501    use crate::{FakeFs, Project, project_settings::ProjectSettings};
 502    use context_server::{
 503        transport::Transport,
 504        types::{
 505            self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
 506            ServerCapabilities,
 507        },
 508    };
 509    use futures::{Stream, StreamExt as _, lock::Mutex};
 510    use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
 511    use serde_json::json;
 512    use std::{cell::RefCell, pin::Pin, rc::Rc};
 513    use util::path;
 514
 515    #[gpui::test]
 516    async fn test_context_server_status(cx: &mut TestAppContext) {
 517        const SERVER_1_ID: &'static str = "mcp-1";
 518        const SERVER_2_ID: &'static str = "mcp-2";
 519
 520        let (_fs, project) = setup_context_server_test(
 521            cx,
 522            json!({"code.rs": ""}),
 523            vec![
 524                (SERVER_1_ID.into(), ContextServerConfiguration::default()),
 525                (SERVER_2_ID.into(), ContextServerConfiguration::default()),
 526            ],
 527        )
 528        .await;
 529
 530        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 531        let store = cx.new(|cx| {
 532            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 533        });
 534
 535        let server_1_id = ContextServerId("mcp-1".into());
 536        let server_2_id = ContextServerId("mcp-2".into());
 537
 538        let transport_1 =
 539            Arc::new(FakeTransport::new(
 540                cx.executor(),
 541                |_, request_type, _| match request_type {
 542                    Some(RequestType::Initialize) => {
 543                        Some(create_initialize_response("mcp-1".to_string()))
 544                    }
 545                    _ => None,
 546                },
 547            ));
 548
 549        let transport_2 =
 550            Arc::new(FakeTransport::new(
 551                cx.executor(),
 552                |_, request_type, _| match request_type {
 553                    Some(RequestType::Initialize) => {
 554                        Some(create_initialize_response("mcp-2".to_string()))
 555                    }
 556                    _ => None,
 557                },
 558            ));
 559
 560        let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
 561        let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
 562
 563        store
 564            .update(cx, |store, cx| store.start_server(server_1, cx))
 565            .unwrap();
 566
 567        cx.run_until_parked();
 568
 569        cx.update(|cx| {
 570            assert_eq!(
 571                store.read(cx).status_for_server(&server_1_id),
 572                Some(ContextServerStatus::Running)
 573            );
 574            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 575        });
 576
 577        store
 578            .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
 579            .unwrap();
 580
 581        cx.run_until_parked();
 582
 583        cx.update(|cx| {
 584            assert_eq!(
 585                store.read(cx).status_for_server(&server_1_id),
 586                Some(ContextServerStatus::Running)
 587            );
 588            assert_eq!(
 589                store.read(cx).status_for_server(&server_2_id),
 590                Some(ContextServerStatus::Running)
 591            );
 592        });
 593
 594        store
 595            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 596            .unwrap();
 597
 598        cx.update(|cx| {
 599            assert_eq!(
 600                store.read(cx).status_for_server(&server_1_id),
 601                Some(ContextServerStatus::Running)
 602            );
 603            assert_eq!(
 604                store.read(cx).status_for_server(&server_2_id),
 605                Some(ContextServerStatus::Stopped)
 606            );
 607        });
 608    }
 609
 610    #[gpui::test]
 611    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 612        const SERVER_1_ID: &'static str = "mcp-1";
 613        const SERVER_2_ID: &'static str = "mcp-2";
 614
 615        let (_fs, project) = setup_context_server_test(
 616            cx,
 617            json!({"code.rs": ""}),
 618            vec![
 619                (SERVER_1_ID.into(), ContextServerConfiguration::default()),
 620                (SERVER_2_ID.into(), ContextServerConfiguration::default()),
 621            ],
 622        )
 623        .await;
 624
 625        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 626        let store = cx.new(|cx| {
 627            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 628        });
 629
 630        let server_1_id = ContextServerId("mcp-1".into());
 631        let server_2_id = ContextServerId("mcp-2".into());
 632
 633        let transport_1 =
 634            Arc::new(FakeTransport::new(
 635                cx.executor(),
 636                |_, request_type, _| match request_type {
 637                    Some(RequestType::Initialize) => {
 638                        Some(create_initialize_response("mcp-1".to_string()))
 639                    }
 640                    _ => None,
 641                },
 642            ));
 643
 644        let transport_2 =
 645            Arc::new(FakeTransport::new(
 646                cx.executor(),
 647                |_, request_type, _| match request_type {
 648                    Some(RequestType::Initialize) => {
 649                        Some(create_initialize_response("mcp-2".to_string()))
 650                    }
 651                    _ => None,
 652                },
 653            ));
 654
 655        let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
 656        let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
 657
 658        let _server_events = assert_server_events(
 659            &store,
 660            vec![
 661                (server_1_id.clone(), ContextServerStatus::Starting),
 662                (server_1_id.clone(), ContextServerStatus::Running),
 663                (server_2_id.clone(), ContextServerStatus::Starting),
 664                (server_2_id.clone(), ContextServerStatus::Running),
 665                (server_2_id.clone(), ContextServerStatus::Stopped),
 666            ],
 667            cx,
 668        );
 669
 670        store
 671            .update(cx, |store, cx| store.start_server(server_1, cx))
 672            .unwrap();
 673
 674        cx.run_until_parked();
 675
 676        store
 677            .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
 678            .unwrap();
 679
 680        cx.run_until_parked();
 681
 682        store
 683            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 684            .unwrap();
 685    }
 686
 687    #[gpui::test(iterations = 25)]
 688    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 689        const SERVER_1_ID: &'static str = "mcp-1";
 690
 691        let (_fs, project) = setup_context_server_test(
 692            cx,
 693            json!({"code.rs": ""}),
 694            vec![(SERVER_1_ID.into(), ContextServerConfiguration::default())],
 695        )
 696        .await;
 697
 698        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 699        let store = cx.new(|cx| {
 700            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 701        });
 702
 703        let server_id = ContextServerId(SERVER_1_ID.into());
 704
 705        let transport_1 =
 706            Arc::new(FakeTransport::new(
 707                cx.executor(),
 708                |_, request_type, _| match request_type {
 709                    Some(RequestType::Initialize) => {
 710                        Some(create_initialize_response(SERVER_1_ID.to_string()))
 711                    }
 712                    _ => None,
 713                },
 714            ));
 715
 716        let transport_2 =
 717            Arc::new(FakeTransport::new(
 718                cx.executor(),
 719                |_, request_type, _| match request_type {
 720                    Some(RequestType::Initialize) => {
 721                        Some(create_initialize_response(SERVER_1_ID.to_string()))
 722                    }
 723                    _ => None,
 724                },
 725            ));
 726
 727        let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
 728        let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
 729
 730        // If we start another server with the same id, we should report that we stopped the previous one
 731        let _server_events = assert_server_events(
 732            &store,
 733            vec![
 734                (server_id.clone(), ContextServerStatus::Starting),
 735                (server_id.clone(), ContextServerStatus::Stopped),
 736                (server_id.clone(), ContextServerStatus::Starting),
 737                (server_id.clone(), ContextServerStatus::Running),
 738            ],
 739            cx,
 740        );
 741
 742        store
 743            .update(cx, |store, cx| {
 744                store.start_server(server_with_same_id_1.clone(), cx)
 745            })
 746            .unwrap();
 747        store
 748            .update(cx, |store, cx| {
 749                store.start_server(server_with_same_id_2.clone(), cx)
 750            })
 751            .unwrap();
 752        cx.update(|cx| {
 753            assert_eq!(
 754                store.read(cx).status_for_server(&server_id),
 755                Some(ContextServerStatus::Starting)
 756            );
 757        });
 758
 759        cx.run_until_parked();
 760
 761        cx.update(|cx| {
 762            assert_eq!(
 763                store.read(cx).status_for_server(&server_id),
 764                Some(ContextServerStatus::Running)
 765            );
 766        });
 767    }
 768
 769    #[gpui::test]
 770    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 771        const SERVER_1_ID: &'static str = "mcp-1";
 772        const SERVER_2_ID: &'static str = "mcp-2";
 773
 774        let server_1_id = ContextServerId(SERVER_1_ID.into());
 775        let server_2_id = ContextServerId(SERVER_2_ID.into());
 776
 777        let (_fs, project) = setup_context_server_test(
 778            cx,
 779            json!({"code.rs": ""}),
 780            vec![(
 781                SERVER_1_ID.into(),
 782                ContextServerConfiguration {
 783                    command: None,
 784                    settings: Some(json!({
 785                        "somevalue": true
 786                    })),
 787                },
 788            )],
 789        )
 790        .await;
 791
 792        let executor = cx.executor();
 793        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 794        let store = cx.new(|cx| {
 795            ContextServerStore::test_maintain_server_loop(
 796                Box::new(move |id, _| {
 797                    let transport = FakeTransport::new(executor.clone(), {
 798                        let id = id.0.clone();
 799                        move |_, request_type, _| match request_type {
 800                            Some(RequestType::Initialize) => {
 801                                Some(create_initialize_response(id.clone().to_string()))
 802                            }
 803                            _ => None,
 804                        }
 805                    });
 806                    Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
 807                }),
 808                registry.clone(),
 809                project.read(cx).worktree_store(),
 810                cx,
 811            )
 812        });
 813
 814        // Ensure that mcp-1 starts up
 815        {
 816            let _server_events = assert_server_events(
 817                &store,
 818                vec![
 819                    (server_1_id.clone(), ContextServerStatus::Starting),
 820                    (server_1_id.clone(), ContextServerStatus::Running),
 821                ],
 822                cx,
 823            );
 824            cx.run_until_parked();
 825        }
 826
 827        // Ensure that mcp-1 is restarted when the configuration was changed
 828        {
 829            let _server_events = assert_server_events(
 830                &store,
 831                vec![
 832                    (server_1_id.clone(), ContextServerStatus::Stopped),
 833                    (server_1_id.clone(), ContextServerStatus::Starting),
 834                    (server_1_id.clone(), ContextServerStatus::Running),
 835                ],
 836                cx,
 837            );
 838            set_context_server_configuration(
 839                vec![(
 840                    server_1_id.0.clone(),
 841                    ContextServerConfiguration {
 842                        command: None,
 843                        settings: Some(json!({
 844                            "somevalue": false
 845                        })),
 846                    },
 847                )],
 848                cx,
 849            );
 850
 851            cx.run_until_parked();
 852        }
 853
 854        // Ensure that mcp-1 is not restarted when the configuration was not changed
 855        {
 856            let _server_events = assert_server_events(&store, vec![], cx);
 857            set_context_server_configuration(
 858                vec![(
 859                    server_1_id.0.clone(),
 860                    ContextServerConfiguration {
 861                        command: None,
 862                        settings: Some(json!({
 863                            "somevalue": false
 864                        })),
 865                    },
 866                )],
 867                cx,
 868            );
 869
 870            cx.run_until_parked();
 871        }
 872
 873        // Ensure that mcp-2 is started once it is added to the settings
 874        {
 875            let _server_events = assert_server_events(
 876                &store,
 877                vec![
 878                    (server_2_id.clone(), ContextServerStatus::Starting),
 879                    (server_2_id.clone(), ContextServerStatus::Running),
 880                ],
 881                cx,
 882            );
 883            set_context_server_configuration(
 884                vec![
 885                    (
 886                        server_1_id.0.clone(),
 887                        ContextServerConfiguration {
 888                            command: None,
 889                            settings: Some(json!({
 890                                "somevalue": false
 891                            })),
 892                        },
 893                    ),
 894                    (
 895                        server_2_id.0.clone(),
 896                        ContextServerConfiguration {
 897                            command: None,
 898                            settings: Some(json!({
 899                                "somevalue": true
 900                            })),
 901                        },
 902                    ),
 903                ],
 904                cx,
 905            );
 906
 907            cx.run_until_parked();
 908        }
 909
 910        // Ensure that mcp-2 is removed once it is removed from the settings
 911        {
 912            let _server_events = assert_server_events(
 913                &store,
 914                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 915                cx,
 916            );
 917            set_context_server_configuration(
 918                vec![(
 919                    server_1_id.0.clone(),
 920                    ContextServerConfiguration {
 921                        command: None,
 922                        settings: Some(json!({
 923                            "somevalue": false
 924                        })),
 925                    },
 926                )],
 927                cx,
 928            );
 929
 930            cx.run_until_parked();
 931
 932            cx.update(|cx| {
 933                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 934            });
 935        }
 936    }
 937
 938    fn set_context_server_configuration(
 939        context_servers: Vec<(Arc<str>, ContextServerConfiguration)>,
 940        cx: &mut TestAppContext,
 941    ) {
 942        cx.update(|cx| {
 943            SettingsStore::update_global(cx, |store, cx| {
 944                let mut settings = ProjectSettings::default();
 945                for (id, config) in context_servers {
 946                    settings.context_servers.insert(id, config);
 947                }
 948                store
 949                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
 950                    .unwrap();
 951            })
 952        });
 953    }
 954
    /// Guard returned by `assert_server_events`.
    ///
    /// While alive it holds the store subscription. When dropped, its `Drop` impl
    /// compares the number of events received against the number expected.
    struct ServerEvents {
        /// Count of events observed so far, incremented by the subscription callback.
        received_event_count: Rc<RefCell<usize>>,
        /// Number of events the test expects to observe.
        expected_event_count: usize,
        /// Held only to keep the subscription created in `assert_server_events`
        /// alive for the guard's lifetime (presumably dropping it unsubscribes).
        _subscription: Subscription,
    }
 960
 961    impl Drop for ServerEvents {
 962        fn drop(&mut self) {
 963            let actual_event_count = *self.received_event_count.borrow();
 964            assert_eq!(
 965                actual_event_count, self.expected_event_count,
 966                "
 967                Expected to receive {} context server store events, but received {} events",
 968                self.expected_event_count, actual_event_count
 969            );
 970        }
 971    }
 972
 973    fn assert_server_events(
 974        store: &Entity<ContextServerStore>,
 975        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
 976        cx: &mut TestAppContext,
 977    ) -> ServerEvents {
 978        cx.update(|cx| {
 979            let mut ix = 0;
 980            let received_event_count = Rc::new(RefCell::new(0));
 981            let expected_event_count = expected_events.len();
 982            let subscription = cx.subscribe(store, {
 983                let received_event_count = received_event_count.clone();
 984                move |_, event, _| match event {
 985                    Event::ServerStatusChanged {
 986                        server_id: actual_server_id,
 987                        status: actual_status,
 988                    } => {
 989                        let (expected_server_id, expected_status) = &expected_events[ix];
 990
 991                        assert_eq!(
 992                            actual_server_id, expected_server_id,
 993                            "Expected different server id at index {}",
 994                            ix
 995                        );
 996                        assert_eq!(
 997                            actual_status, expected_status,
 998                            "Expected different status at index {}",
 999                            ix
1000                        );
1001                        ix += 1;
1002                        *received_event_count.borrow_mut() += 1;
1003                    }
1004                }
1005            });
1006            ServerEvents {
1007                expected_event_count,
1008                received_event_count,
1009                _subscription: subscription,
1010            }
1011        })
1012    }
1013
1014    async fn setup_context_server_test(
1015        cx: &mut TestAppContext,
1016        files: serde_json::Value,
1017        context_server_configurations: Vec<(Arc<str>, ContextServerConfiguration)>,
1018    ) -> (Arc<FakeFs>, Entity<Project>) {
1019        cx.update(|cx| {
1020            let settings_store = SettingsStore::test(cx);
1021            cx.set_global(settings_store);
1022            Project::init_settings(cx);
1023            let mut settings = ProjectSettings::get_global(cx).clone();
1024            for (id, config) in context_server_configurations {
1025                settings.context_servers.insert(id, config);
1026            }
1027            ProjectSettings::override_global(settings, cx);
1028        });
1029
1030        let fs = FakeFs::new(cx.executor());
1031        fs.insert_tree(path!("/test"), files).await;
1032        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1033
1034        (fs, project)
1035    }
1036
1037    fn create_initialize_response(server_name: String) -> serde_json::Value {
1038        serde_json::to_value(&InitializeResponse {
1039            protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
1040            server_info: Implementation {
1041                name: server_name,
1042                version: "1.0.0".to_string(),
1043            },
1044            capabilities: ServerCapabilities::default(),
1045            meta: None,
1046        })
1047        .unwrap()
1048    }
1049
    /// In-memory transport for tests.
    ///
    /// Requests passed to `send` are answered by the `on_request` callback. The
    /// answers are pushed into an unbounded channel that `receive` reads from.
    struct FakeTransport {
        /// Produces the `result` payload for a request (id, recognized request type,
        /// raw JSON); returning `None` means no response is sent.
        on_request: Arc<
            dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
                + Send
                + Sync,
        >,
        /// Sending half of the response channel.
        tx: futures::channel::mpsc::UnboundedSender<String>,
        /// Receiving half of the response channel. It sits behind a lock that is shared
        /// by every stream `receive` returns.
        rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
        /// Used to inject a random delay before each response is yielded.
        executor: BackgroundExecutor,
    }
1060
1061    impl FakeTransport {
1062        fn new(
1063            executor: BackgroundExecutor,
1064            on_request: impl Fn(
1065                u64,
1066                Option<RequestType>,
1067                serde_json::Value,
1068            ) -> Option<serde_json::Value>
1069            + 'static
1070            + Send
1071            + Sync,
1072        ) -> Self {
1073            let (tx, rx) = futures::channel::mpsc::unbounded();
1074            Self {
1075                on_request: Arc::new(on_request),
1076                tx,
1077                rx: Arc::new(Mutex::new(rx)),
1078                executor,
1079            }
1080        }
1081    }
1082
1083    #[async_trait::async_trait]
1084    impl Transport for FakeTransport {
1085        async fn send(&self, message: String) -> Result<()> {
1086            if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
1087                let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
1088
1089                if let Some(method) = msg.get("method") {
1090                    let request_type = method
1091                        .as_str()
1092                        .and_then(|method| types::RequestType::try_from(method).ok());
1093                    if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
1094                        let response = serde_json::json!({
1095                            "jsonrpc": "2.0",
1096                            "id": id,
1097                            "result": payload
1098                        });
1099
1100                        self.tx
1101                            .unbounded_send(response.to_string())
1102                            .context("sending a message")?;
1103                    }
1104                }
1105            }
1106            Ok(())
1107        }
1108
1109        fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
1110            let rx = self.rx.clone();
1111            let executor = self.executor.clone();
1112            Box::pin(futures::stream::unfold(rx, move |rx| {
1113                let executor = executor.clone();
1114                async move {
1115                    let mut rx_guard = rx.lock().await;
1116                    executor.simulate_random_delay().await;
1117                    if let Some(message) = rx_guard.next().await {
1118                        drop(rx_guard);
1119                        Some((message, rx))
1120                    } else {
1121                        None
1122                    }
1123                }
1124            }))
1125        }
1126
1127        fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
1128            Box::pin(futures::stream::empty())
1129        }
1130    }
1131}