context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::{path::Path, sync::Arc};
   5
   6use anyhow::{Context as _, Result};
   7use collections::{HashMap, HashSet};
   8use context_server::{ContextServer, ContextServerId};
   9use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  10use registry::ContextServerDescriptorRegistry;
  11use settings::{Settings as _, SettingsStore};
  12use util::ResultExt as _;
  13
  14use crate::{
  15    project_settings::{ContextServerConfiguration, ProjectSettings},
  16    worktree_store::WorktreeStore,
  17};
  18
/// Initializes context-server support; currently this only delegates to
/// [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
  22
// Declares the `context_server::Restart` action. Only the declaration lives in
// the visible part of this file; its handler presumably lives elsewhere — verify.
actions!(context_server, [Restart]);
  24
  25#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  26pub enum ContextServerStatus {
  27    Starting,
  28    Running,
  29    Stopped,
  30    Error(Arc<str>),
  31}
  32
  33impl ContextServerStatus {
  34    fn from_state(state: &ContextServerState) -> Self {
  35        match state {
  36            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  37            ContextServerState::Running { .. } => ContextServerStatus::Running,
  38            ContextServerState::Stopped { error, .. } => {
  39                if let Some(error) = error {
  40                    ContextServerStatus::Error(error.clone())
  41                } else {
  42                    ContextServerStatus::Stopped
  43                }
  44            }
  45        }
  46    }
  47}
  48
/// Internal lifecycle state of a server tracked by `ContextServerStore`.
///
/// Every variant keeps the server handle and the configuration it was built
/// from, so the server can be restarted and its configuration compared against
/// freshly loaded settings.
enum ContextServerState {
    /// `ContextServer::start` has been spawned and has not completed yet.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Never read: holding the handle keeps the startup task alive. Replacing
        // or removing this state drops it, and gpui cancels a dropped task.
        _task: Task<()>,
    },
    /// `ContextServer::start` completed successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// Stopped explicitly (`error` is `None`) or failed to start (`error`
    /// holds the failure message).
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Option<Arc<str>>,
    },
}
  65
  66impl ContextServerState {
  67    pub fn server(&self) -> Arc<ContextServer> {
  68        match self {
  69            ContextServerState::Starting { server, .. } => server.clone(),
  70            ContextServerState::Running { server, .. } => server.clone(),
  71            ContextServerState::Stopped { server, .. } => server.clone(),
  72        }
  73    }
  74
  75    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  76        match self {
  77            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  78            ContextServerState::Running { configuration, .. } => configuration.clone(),
  79            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  80        }
  81    }
  82}
  83
/// Builds a [`ContextServer`] for a given id and configuration.
///
/// When a factory is installed (only via the test-only
/// `ContextServerStore::test_maintain_server_loop`), it replaces the default
/// construction of a stdio server from the configuration's command.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
  86
/// Owns a project's context servers and keeps them in sync with the user's
/// settings and the registered server descriptors.
pub struct ContextServerStore {
    /// Current lifecycle state of every known server, keyed by id.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Used to resolve settings relative to the first visible worktree.
    worktree_store: Entity<WorktreeStore>,
    /// Source of descriptors for servers that are not fully configured in settings.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The in-flight reconciliation pass, if one is running.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for building servers (set only by a test constructor).
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a change arrives while a reconciliation pass is running, so a
    /// follow-up pass is scheduled once the current one finishes.
    needs_server_update: bool,
    /// Registry/settings observers; empty when the maintenance loop is disabled.
    _subscriptions: Vec<Subscription>,
}
  96
/// Events emitted by `ContextServerStore`.
pub enum Event {
    /// A server's status changed. Removing a server from the store is also
    /// reported through this event, with [`ContextServerStatus::Stopped`].
    ServerStatusChanged {
        server_id: ContextServerId,
        status: ContextServerStatus,
    },
}
 103
// Allows other entities to subscribe to the store's `Event`s.
impl EventEmitter<Event> for ContextServerStore {}
 105
 106impl ContextServerStore {
 107    pub fn new(worktree_store: Entity<WorktreeStore>, cx: &mut Context<Self>) -> Self {
 108        Self::new_internal(
 109            true,
 110            None,
 111            ContextServerDescriptorRegistry::default_global(cx),
 112            worktree_store,
 113            cx,
 114        )
 115    }
 116
 117    #[cfg(any(test, feature = "test-support"))]
 118    pub fn test(
 119        registry: Entity<ContextServerDescriptorRegistry>,
 120        worktree_store: Entity<WorktreeStore>,
 121        cx: &mut Context<Self>,
 122    ) -> Self {
 123        Self::new_internal(false, None, registry, worktree_store, cx)
 124    }
 125
 126    #[cfg(any(test, feature = "test-support"))]
 127    pub fn test_maintain_server_loop(
 128        context_server_factory: ContextServerFactory,
 129        registry: Entity<ContextServerDescriptorRegistry>,
 130        worktree_store: Entity<WorktreeStore>,
 131        cx: &mut Context<Self>,
 132    ) -> Self {
 133        Self::new_internal(
 134            true,
 135            Some(context_server_factory),
 136            registry,
 137            worktree_store,
 138            cx,
 139        )
 140    }
 141
 142    fn new_internal(
 143        maintain_server_loop: bool,
 144        context_server_factory: Option<ContextServerFactory>,
 145        registry: Entity<ContextServerDescriptorRegistry>,
 146        worktree_store: Entity<WorktreeStore>,
 147        cx: &mut Context<Self>,
 148    ) -> Self {
 149        let subscriptions = if maintain_server_loop {
 150            vec![
 151                cx.observe(&registry, |this, _registry, cx| {
 152                    this.available_context_servers_changed(cx);
 153                }),
 154                cx.observe_global::<SettingsStore>(|this, cx| {
 155                    this.available_context_servers_changed(cx);
 156                }),
 157            ]
 158        } else {
 159            Vec::new()
 160        };
 161
 162        let mut this = Self {
 163            _subscriptions: subscriptions,
 164            worktree_store,
 165            registry,
 166            needs_server_update: false,
 167            servers: HashMap::default(),
 168            update_servers_task: None,
 169            context_server_factory,
 170        };
 171        if maintain_server_loop {
 172            this.available_context_servers_changed(cx);
 173        }
 174        this
 175    }
 176
 177    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 178        self.servers.get(id).map(|state| state.server())
 179    }
 180
 181    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 182        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 183            Some(server.clone())
 184        } else {
 185            None
 186        }
 187    }
 188
 189    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 190        self.servers.get(id).map(ContextServerStatus::from_state)
 191    }
 192
 193    pub fn all_server_ids(&self) -> Vec<ContextServerId> {
 194        self.servers.keys().cloned().collect()
 195    }
 196
 197    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 198        self.servers
 199            .values()
 200            .filter_map(|state| {
 201                if let ContextServerState::Running { server, .. } = state {
 202                    Some(server.clone())
 203                } else {
 204                    None
 205                }
 206            })
 207            .collect()
 208    }
 209
 210    pub fn start_server(
 211        &mut self,
 212        server: Arc<ContextServer>,
 213        cx: &mut Context<Self>,
 214    ) -> Result<()> {
 215        let location = self
 216            .worktree_store
 217            .read(cx)
 218            .visible_worktrees(cx)
 219            .next()
 220            .map(|worktree| settings::SettingsLocation {
 221                worktree_id: worktree.read(cx).id(),
 222                path: Path::new(""),
 223            });
 224        let settings = ProjectSettings::get(location, cx);
 225        let configuration = settings
 226            .context_servers
 227            .get(&server.id().0)
 228            .context("Failed to load context server configuration from settings")?
 229            .clone();
 230
 231        self.run_server(server, Arc::new(configuration), cx);
 232        Ok(())
 233    }
 234
 235    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 236        let Some(state) = self.servers.remove(id) else {
 237            return Err(anyhow::anyhow!("Context server not found"));
 238        };
 239
 240        let server = state.server();
 241        let configuration = state.configuration();
 242        let mut result = Ok(());
 243        if let ContextServerState::Running { server, .. } = &state {
 244            result = server.stop();
 245        }
 246        drop(state);
 247
 248        self.update_server_state(
 249            id.clone(),
 250            ContextServerState::Stopped {
 251                configuration,
 252                server,
 253                error: None,
 254            },
 255            cx,
 256        );
 257
 258        result
 259    }
 260
 261    pub fn restart_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 262        if let Some(state) = self.servers.get(&id) {
 263            let configuration = state.configuration();
 264
 265            self.stop_server(&state.server().id(), cx)?;
 266            let new_server = self.create_context_server(id.clone(), configuration.clone())?;
 267            self.run_server(new_server, configuration, cx);
 268        }
 269        Ok(())
 270    }
 271
 272    fn run_server(
 273        &mut self,
 274        server: Arc<ContextServer>,
 275        configuration: Arc<ContextServerConfiguration>,
 276        cx: &mut Context<Self>,
 277    ) {
 278        let id = server.id();
 279        if matches!(
 280            self.servers.get(&id),
 281            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 282        ) {
 283            self.stop_server(&id, cx).log_err();
 284        }
 285
 286        let task = cx.spawn({
 287            let id = server.id();
 288            let server = server.clone();
 289            let configuration = configuration.clone();
 290            async move |this, cx| {
 291                match server.clone().start(&cx).await {
 292                    Ok(_) => {
 293                        log::info!("Started {} context server", id);
 294                        debug_assert!(server.client().is_some());
 295
 296                        this.update(cx, |this, cx| {
 297                            this.update_server_state(
 298                                id.clone(),
 299                                ContextServerState::Running {
 300                                    server,
 301                                    configuration,
 302                                },
 303                                cx,
 304                            )
 305                        })
 306                        .log_err()
 307                    }
 308                    Err(err) => {
 309                        log::error!("{} context server failed to start: {}", id, err);
 310                        this.update(cx, |this, cx| {
 311                            this.update_server_state(
 312                                id.clone(),
 313                                ContextServerState::Stopped {
 314                                    configuration,
 315                                    server,
 316                                    error: Some(err.to_string().into()),
 317                                },
 318                                cx,
 319                            )
 320                        })
 321                        .log_err()
 322                    }
 323                };
 324            }
 325        });
 326
 327        self.update_server_state(
 328            id.clone(),
 329            ContextServerState::Starting {
 330                configuration,
 331                _task: task,
 332                server,
 333            },
 334            cx,
 335        );
 336    }
 337
 338    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 339        let Some(state) = self.servers.remove(id) else {
 340            return Err(anyhow::anyhow!("Context server not found"));
 341        };
 342        drop(state);
 343        cx.emit(Event::ServerStatusChanged {
 344            server_id: id.clone(),
 345            status: ContextServerStatus::Stopped,
 346        });
 347        Ok(())
 348    }
 349
 350    fn is_configuration_valid(&self, configuration: &ContextServerConfiguration) -> bool {
 351        // Command must be some when we are running in stdio mode.
 352        self.context_server_factory.as_ref().is_some() || configuration.command.is_some()
 353    }
 354
 355    fn create_context_server(
 356        &self,
 357        id: ContextServerId,
 358        configuration: Arc<ContextServerConfiguration>,
 359    ) -> Result<Arc<ContextServer>> {
 360        if let Some(factory) = self.context_server_factory.as_ref() {
 361            Ok(factory(id, configuration))
 362        } else {
 363            let command = configuration
 364                .command
 365                .clone()
 366                .context("Missing command to run context server")?;
 367            Ok(Arc::new(ContextServer::stdio(id, command)))
 368        }
 369    }
 370
 371    fn update_server_state(
 372        &mut self,
 373        id: ContextServerId,
 374        state: ContextServerState,
 375        cx: &mut Context<Self>,
 376    ) {
 377        let status = ContextServerStatus::from_state(&state);
 378        self.servers.insert(id.clone(), state);
 379        cx.emit(Event::ServerStatusChanged {
 380            server_id: id,
 381            status,
 382        });
 383    }
 384
 385    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 386        if self.update_servers_task.is_some() {
 387            self.needs_server_update = true;
 388        } else {
 389            self.needs_server_update = false;
 390            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 391                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 392                    log::error!("Error maintaining context servers: {}", err);
 393                }
 394
 395                this.update(cx, |this, cx| {
 396                    this.update_servers_task.take();
 397                    if this.needs_server_update {
 398                        this.available_context_servers_changed(cx);
 399                    }
 400                })?;
 401
 402                Ok(())
 403            }));
 404        }
 405    }
 406
 407    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 408        let mut desired_servers = HashMap::default();
 409
 410        let (registry, worktree_store) = this.update(cx, |this, cx| {
 411            let location = this
 412                .worktree_store
 413                .read(cx)
 414                .visible_worktrees(cx)
 415                .next()
 416                .map(|worktree| settings::SettingsLocation {
 417                    worktree_id: worktree.read(cx).id(),
 418                    path: Path::new(""),
 419                });
 420            let settings = ProjectSettings::get(location, cx);
 421            desired_servers = settings.context_servers.clone();
 422
 423            (this.registry.clone(), this.worktree_store.clone())
 424        })?;
 425
 426        for (id, descriptor) in
 427            registry.read_with(cx, |registry, _| registry.context_server_descriptors())?
 428        {
 429            let config = desired_servers.entry(id.clone()).or_default();
 430            if config.command.is_none() {
 431                if let Some(extension_command) = descriptor
 432                    .command(worktree_store.clone(), &cx)
 433                    .await
 434                    .log_err()
 435                {
 436                    config.command = Some(extension_command);
 437                }
 438            }
 439        }
 440
 441        this.update(cx, |this, _| {
 442            // Filter out configurations without commands, the user uninstalled an extension.
 443            desired_servers.retain(|_, configuration| this.is_configuration_valid(configuration));
 444        })?;
 445
 446        let mut servers_to_start = Vec::new();
 447        let mut servers_to_remove = HashSet::default();
 448        let mut servers_to_stop = HashSet::default();
 449
 450        this.update(cx, |this, _cx| {
 451            for server_id in this.servers.keys() {
 452                // All servers that are not in desired_servers should be removed from the store.
 453                // E.g. this can happen if the user removed a server from the configuration,
 454                // or the user uninstalled an extension.
 455                if !desired_servers.contains_key(&server_id.0) {
 456                    servers_to_remove.insert(server_id.clone());
 457                }
 458            }
 459
 460            for (id, config) in desired_servers {
 461                let id = ContextServerId(id.clone());
 462
 463                let existing_config = this.servers.get(&id).map(|state| state.configuration());
 464                if existing_config.as_deref() != Some(&config) {
 465                    let config = Arc::new(config);
 466                    if let Some(server) = this
 467                        .create_context_server(id.clone(), config.clone())
 468                        .log_err()
 469                    {
 470                        servers_to_start.push((server, config));
 471                        if this.servers.contains_key(&id) {
 472                            servers_to_stop.insert(id);
 473                        }
 474                    }
 475                }
 476            }
 477        })?;
 478
 479        for id in servers_to_stop {
 480            this.update(cx, |this, cx| this.stop_server(&id, cx).ok())?;
 481        }
 482
 483        for id in servers_to_remove {
 484            this.update(cx, |this, cx| this.remove_server(&id, cx).ok())?;
 485        }
 486
 487        for (server, config) in servers_to_start {
 488            this.update(cx, |this, cx| this.run_server(server, config, cx))
 489                .log_err();
 490        }
 491
 492        Ok(())
 493    }
 494}
 495
 496#[cfg(test)]
 497mod tests {
 498    use super::*;
 499    use crate::{FakeFs, Project, project_settings::ProjectSettings};
 500    use context_server::{
 501        transport::Transport,
 502        types::{
 503            self, Implementation, InitializeResponse, ProtocolVersion, RequestType,
 504            ServerCapabilities,
 505        },
 506    };
 507    use futures::{Stream, StreamExt as _, lock::Mutex};
 508    use gpui::{AppContext, BackgroundExecutor, TestAppContext, UpdateGlobal as _};
 509    use serde_json::json;
 510    use std::{cell::RefCell, pin::Pin, rc::Rc};
 511    use util::path;
 512
 513    #[gpui::test]
 514    async fn test_context_server_status(cx: &mut TestAppContext) {
 515        const SERVER_1_ID: &'static str = "mcp-1";
 516        const SERVER_2_ID: &'static str = "mcp-2";
 517
 518        let (_fs, project) = setup_context_server_test(
 519            cx,
 520            json!({"code.rs": ""}),
 521            vec![
 522                (SERVER_1_ID.into(), ContextServerConfiguration::default()),
 523                (SERVER_2_ID.into(), ContextServerConfiguration::default()),
 524            ],
 525        )
 526        .await;
 527
 528        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 529        let store = cx.new(|cx| {
 530            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 531        });
 532
 533        let server_1_id = ContextServerId("mcp-1".into());
 534        let server_2_id = ContextServerId("mcp-2".into());
 535
 536        let transport_1 =
 537            Arc::new(FakeTransport::new(
 538                cx.executor(),
 539                |_, request_type, _| match request_type {
 540                    Some(RequestType::Initialize) => {
 541                        Some(create_initialize_response("mcp-1".to_string()))
 542                    }
 543                    _ => None,
 544                },
 545            ));
 546
 547        let transport_2 =
 548            Arc::new(FakeTransport::new(
 549                cx.executor(),
 550                |_, request_type, _| match request_type {
 551                    Some(RequestType::Initialize) => {
 552                        Some(create_initialize_response("mcp-2".to_string()))
 553                    }
 554                    _ => None,
 555                },
 556            ));
 557
 558        let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
 559        let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
 560
 561        store
 562            .update(cx, |store, cx| store.start_server(server_1, cx))
 563            .unwrap();
 564
 565        cx.run_until_parked();
 566
 567        cx.update(|cx| {
 568            assert_eq!(
 569                store.read(cx).status_for_server(&server_1_id),
 570                Some(ContextServerStatus::Running)
 571            );
 572            assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 573        });
 574
 575        store
 576            .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
 577            .unwrap();
 578
 579        cx.run_until_parked();
 580
 581        cx.update(|cx| {
 582            assert_eq!(
 583                store.read(cx).status_for_server(&server_1_id),
 584                Some(ContextServerStatus::Running)
 585            );
 586            assert_eq!(
 587                store.read(cx).status_for_server(&server_2_id),
 588                Some(ContextServerStatus::Running)
 589            );
 590        });
 591
 592        store
 593            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 594            .unwrap();
 595
 596        cx.update(|cx| {
 597            assert_eq!(
 598                store.read(cx).status_for_server(&server_1_id),
 599                Some(ContextServerStatus::Running)
 600            );
 601            assert_eq!(
 602                store.read(cx).status_for_server(&server_2_id),
 603                Some(ContextServerStatus::Stopped)
 604            );
 605        });
 606    }
 607
 608    #[gpui::test]
 609    async fn test_context_server_status_events(cx: &mut TestAppContext) {
 610        const SERVER_1_ID: &'static str = "mcp-1";
 611        const SERVER_2_ID: &'static str = "mcp-2";
 612
 613        let (_fs, project) = setup_context_server_test(
 614            cx,
 615            json!({"code.rs": ""}),
 616            vec![
 617                (SERVER_1_ID.into(), ContextServerConfiguration::default()),
 618                (SERVER_2_ID.into(), ContextServerConfiguration::default()),
 619            ],
 620        )
 621        .await;
 622
 623        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 624        let store = cx.new(|cx| {
 625            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 626        });
 627
 628        let server_1_id = ContextServerId("mcp-1".into());
 629        let server_2_id = ContextServerId("mcp-2".into());
 630
 631        let transport_1 =
 632            Arc::new(FakeTransport::new(
 633                cx.executor(),
 634                |_, request_type, _| match request_type {
 635                    Some(RequestType::Initialize) => {
 636                        Some(create_initialize_response("mcp-1".to_string()))
 637                    }
 638                    _ => None,
 639                },
 640            ));
 641
 642        let transport_2 =
 643            Arc::new(FakeTransport::new(
 644                cx.executor(),
 645                |_, request_type, _| match request_type {
 646                    Some(RequestType::Initialize) => {
 647                        Some(create_initialize_response("mcp-2".to_string()))
 648                    }
 649                    _ => None,
 650                },
 651            ));
 652
 653        let server_1 = Arc::new(ContextServer::new(server_1_id.clone(), transport_1.clone()));
 654        let server_2 = Arc::new(ContextServer::new(server_2_id.clone(), transport_2.clone()));
 655
 656        let _server_events = assert_server_events(
 657            &store,
 658            vec![
 659                (server_1_id.clone(), ContextServerStatus::Starting),
 660                (server_1_id.clone(), ContextServerStatus::Running),
 661                (server_2_id.clone(), ContextServerStatus::Starting),
 662                (server_2_id.clone(), ContextServerStatus::Running),
 663                (server_2_id.clone(), ContextServerStatus::Stopped),
 664            ],
 665            cx,
 666        );
 667
 668        store
 669            .update(cx, |store, cx| store.start_server(server_1, cx))
 670            .unwrap();
 671
 672        cx.run_until_parked();
 673
 674        store
 675            .update(cx, |store, cx| store.start_server(server_2.clone(), cx))
 676            .unwrap();
 677
 678        cx.run_until_parked();
 679
 680        store
 681            .update(cx, |store, cx| store.stop_server(&server_2_id, cx))
 682            .unwrap();
 683    }
 684
 685    #[gpui::test(iterations = 25)]
 686    async fn test_context_server_concurrent_starts(cx: &mut TestAppContext) {
 687        const SERVER_1_ID: &'static str = "mcp-1";
 688
 689        let (_fs, project) = setup_context_server_test(
 690            cx,
 691            json!({"code.rs": ""}),
 692            vec![(SERVER_1_ID.into(), ContextServerConfiguration::default())],
 693        )
 694        .await;
 695
 696        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 697        let store = cx.new(|cx| {
 698            ContextServerStore::test(registry.clone(), project.read(cx).worktree_store(), cx)
 699        });
 700
 701        let server_id = ContextServerId(SERVER_1_ID.into());
 702
 703        let transport_1 =
 704            Arc::new(FakeTransport::new(
 705                cx.executor(),
 706                |_, request_type, _| match request_type {
 707                    Some(RequestType::Initialize) => {
 708                        Some(create_initialize_response(SERVER_1_ID.to_string()))
 709                    }
 710                    _ => None,
 711                },
 712            ));
 713
 714        let transport_2 =
 715            Arc::new(FakeTransport::new(
 716                cx.executor(),
 717                |_, request_type, _| match request_type {
 718                    Some(RequestType::Initialize) => {
 719                        Some(create_initialize_response(SERVER_1_ID.to_string()))
 720                    }
 721                    _ => None,
 722                },
 723            ));
 724
 725        let server_with_same_id_1 = Arc::new(ContextServer::new(server_id.clone(), transport_1));
 726        let server_with_same_id_2 = Arc::new(ContextServer::new(server_id.clone(), transport_2));
 727
 728        // If we start another server with the same id, we should report that we stopped the previous one
 729        let _server_events = assert_server_events(
 730            &store,
 731            vec![
 732                (server_id.clone(), ContextServerStatus::Starting),
 733                (server_id.clone(), ContextServerStatus::Stopped),
 734                (server_id.clone(), ContextServerStatus::Starting),
 735                (server_id.clone(), ContextServerStatus::Running),
 736            ],
 737            cx,
 738        );
 739
 740        store
 741            .update(cx, |store, cx| {
 742                store.start_server(server_with_same_id_1.clone(), cx)
 743            })
 744            .unwrap();
 745        store
 746            .update(cx, |store, cx| {
 747                store.start_server(server_with_same_id_2.clone(), cx)
 748            })
 749            .unwrap();
 750        cx.update(|cx| {
 751            assert_eq!(
 752                store.read(cx).status_for_server(&server_id),
 753                Some(ContextServerStatus::Starting)
 754            );
 755        });
 756
 757        cx.run_until_parked();
 758
 759        cx.update(|cx| {
 760            assert_eq!(
 761                store.read(cx).status_for_server(&server_id),
 762                Some(ContextServerStatus::Running)
 763            );
 764        });
 765    }
 766
 767    #[gpui::test]
 768    async fn test_context_server_maintain_servers_loop(cx: &mut TestAppContext) {
 769        const SERVER_1_ID: &'static str = "mcp-1";
 770        const SERVER_2_ID: &'static str = "mcp-2";
 771
 772        let server_1_id = ContextServerId(SERVER_1_ID.into());
 773        let server_2_id = ContextServerId(SERVER_2_ID.into());
 774
 775        let (_fs, project) = setup_context_server_test(
 776            cx,
 777            json!({"code.rs": ""}),
 778            vec![(
 779                SERVER_1_ID.into(),
 780                ContextServerConfiguration {
 781                    command: None,
 782                    settings: Some(json!({
 783                        "somevalue": true
 784                    })),
 785                },
 786            )],
 787        )
 788        .await;
 789
 790        let executor = cx.executor();
 791        let registry = cx.new(|_| ContextServerDescriptorRegistry::new());
 792        let store = cx.new(|cx| {
 793            ContextServerStore::test_maintain_server_loop(
 794                Box::new(move |id, _| {
 795                    let transport = FakeTransport::new(executor.clone(), {
 796                        let id = id.0.clone();
 797                        move |_, request_type, _| match request_type {
 798                            Some(RequestType::Initialize) => {
 799                                Some(create_initialize_response(id.clone().to_string()))
 800                            }
 801                            _ => None,
 802                        }
 803                    });
 804                    Arc::new(ContextServer::new(id.clone(), Arc::new(transport)))
 805                }),
 806                registry.clone(),
 807                project.read(cx).worktree_store(),
 808                cx,
 809            )
 810        });
 811
 812        // Ensure that mcp-1 starts up
 813        {
 814            let _server_events = assert_server_events(
 815                &store,
 816                vec![
 817                    (server_1_id.clone(), ContextServerStatus::Starting),
 818                    (server_1_id.clone(), ContextServerStatus::Running),
 819                ],
 820                cx,
 821            );
 822            cx.run_until_parked();
 823        }
 824
 825        // Ensure that mcp-1 is restarted when the configuration was changed
 826        {
 827            let _server_events = assert_server_events(
 828                &store,
 829                vec![
 830                    (server_1_id.clone(), ContextServerStatus::Stopped),
 831                    (server_1_id.clone(), ContextServerStatus::Starting),
 832                    (server_1_id.clone(), ContextServerStatus::Running),
 833                ],
 834                cx,
 835            );
 836            set_context_server_configuration(
 837                vec![(
 838                    server_1_id.0.clone(),
 839                    ContextServerConfiguration {
 840                        command: None,
 841                        settings: Some(json!({
 842                            "somevalue": false
 843                        })),
 844                    },
 845                )],
 846                cx,
 847            );
 848
 849            cx.run_until_parked();
 850        }
 851
 852        // Ensure that mcp-1 is not restarted when the configuration was not changed
 853        {
 854            let _server_events = assert_server_events(&store, vec![], cx);
 855            set_context_server_configuration(
 856                vec![(
 857                    server_1_id.0.clone(),
 858                    ContextServerConfiguration {
 859                        command: None,
 860                        settings: Some(json!({
 861                            "somevalue": false
 862                        })),
 863                    },
 864                )],
 865                cx,
 866            );
 867
 868            cx.run_until_parked();
 869        }
 870
 871        // Ensure that mcp-2 is started once it is added to the settings
 872        {
 873            let _server_events = assert_server_events(
 874                &store,
 875                vec![
 876                    (server_2_id.clone(), ContextServerStatus::Starting),
 877                    (server_2_id.clone(), ContextServerStatus::Running),
 878                ],
 879                cx,
 880            );
 881            set_context_server_configuration(
 882                vec![
 883                    (
 884                        server_1_id.0.clone(),
 885                        ContextServerConfiguration {
 886                            command: None,
 887                            settings: Some(json!({
 888                                "somevalue": false
 889                            })),
 890                        },
 891                    ),
 892                    (
 893                        server_2_id.0.clone(),
 894                        ContextServerConfiguration {
 895                            command: None,
 896                            settings: Some(json!({
 897                                "somevalue": true
 898                            })),
 899                        },
 900                    ),
 901                ],
 902                cx,
 903            );
 904
 905            cx.run_until_parked();
 906        }
 907
 908        // Ensure that mcp-2 is removed once it is removed from the settings
 909        {
 910            let _server_events = assert_server_events(
 911                &store,
 912                vec![(server_2_id.clone(), ContextServerStatus::Stopped)],
 913                cx,
 914            );
 915            set_context_server_configuration(
 916                vec![(
 917                    server_1_id.0.clone(),
 918                    ContextServerConfiguration {
 919                        command: None,
 920                        settings: Some(json!({
 921                            "somevalue": false
 922                        })),
 923                    },
 924                )],
 925                cx,
 926            );
 927
 928            cx.run_until_parked();
 929
 930            cx.update(|cx| {
 931                assert_eq!(store.read(cx).status_for_server(&server_2_id), None);
 932            });
 933        }
 934    }
 935
 936    fn set_context_server_configuration(
 937        context_servers: Vec<(Arc<str>, ContextServerConfiguration)>,
 938        cx: &mut TestAppContext,
 939    ) {
 940        cx.update(|cx| {
 941            SettingsStore::update_global(cx, |store, cx| {
 942                let mut settings = ProjectSettings::default();
 943                for (id, config) in context_servers {
 944                    settings.context_servers.insert(id, config);
 945                }
 946                store
 947                    .set_user_settings(&serde_json::to_string(&settings).unwrap(), cx)
 948                    .unwrap();
 949            })
 950        });
 951    }
 952
    /// Guard returned by `assert_server_events`.
    ///
    /// When dropped, it compares the number of events counted by the store
    /// subscription against the number of events the test expected.
    struct ServerEvents {
        /// Incremented by the subscription callback for every matched event.
        received_event_count: Rc<RefCell<usize>>,
        /// Length of the `expected_events` list passed to `assert_server_events`.
        expected_event_count: usize,
        /// Held so the store subscription lives as long as this guard
        /// (assumes gpui's `Subscription` unsubscribes on drop).
        _subscription: Subscription,
    }
 958
 959    impl Drop for ServerEvents {
 960        fn drop(&mut self) {
 961            let actual_event_count = *self.received_event_count.borrow();
 962            assert_eq!(
 963                actual_event_count, self.expected_event_count,
 964                "
 965                Expected to receive {} context server store events, but received {} events",
 966                self.expected_event_count, actual_event_count
 967            );
 968        }
 969    }
 970
 971    fn assert_server_events(
 972        store: &Entity<ContextServerStore>,
 973        expected_events: Vec<(ContextServerId, ContextServerStatus)>,
 974        cx: &mut TestAppContext,
 975    ) -> ServerEvents {
 976        cx.update(|cx| {
 977            let mut ix = 0;
 978            let received_event_count = Rc::new(RefCell::new(0));
 979            let expected_event_count = expected_events.len();
 980            let subscription = cx.subscribe(store, {
 981                let received_event_count = received_event_count.clone();
 982                move |_, event, _| match event {
 983                    Event::ServerStatusChanged {
 984                        server_id: actual_server_id,
 985                        status: actual_status,
 986                    } => {
 987                        let (expected_server_id, expected_status) = &expected_events[ix];
 988
 989                        assert_eq!(
 990                            actual_server_id, expected_server_id,
 991                            "Expected different server id at index {}",
 992                            ix
 993                        );
 994                        assert_eq!(
 995                            actual_status, expected_status,
 996                            "Expected different status at index {}",
 997                            ix
 998                        );
 999                        ix += 1;
1000                        *received_event_count.borrow_mut() += 1;
1001                    }
1002                }
1003            });
1004            ServerEvents {
1005                expected_event_count,
1006                received_event_count,
1007                _subscription: subscription,
1008            }
1009        })
1010    }
1011
1012    async fn setup_context_server_test(
1013        cx: &mut TestAppContext,
1014        files: serde_json::Value,
1015        context_server_configurations: Vec<(Arc<str>, ContextServerConfiguration)>,
1016    ) -> (Arc<FakeFs>, Entity<Project>) {
1017        cx.update(|cx| {
1018            let settings_store = SettingsStore::test(cx);
1019            cx.set_global(settings_store);
1020            Project::init_settings(cx);
1021            let mut settings = ProjectSettings::get_global(cx).clone();
1022            for (id, config) in context_server_configurations {
1023                settings.context_servers.insert(id, config);
1024            }
1025            ProjectSettings::override_global(settings, cx);
1026        });
1027
1028        let fs = FakeFs::new(cx.executor());
1029        fs.insert_tree(path!("/test"), files).await;
1030        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1031
1032        (fs, project)
1033    }
1034
1035    fn create_initialize_response(server_name: String) -> serde_json::Value {
1036        serde_json::to_value(&InitializeResponse {
1037            protocol_version: ProtocolVersion(types::LATEST_PROTOCOL_VERSION.to_string()),
1038            server_info: Implementation {
1039                name: server_name,
1040                version: "1.0.0".to_string(),
1041            },
1042            capabilities: ServerCapabilities::default(),
1043            meta: None,
1044        })
1045        .unwrap()
1046    }
1047
    /// In-memory `Transport` for tests.
    ///
    /// Outgoing messages are answered by `on_request`. Any responses are queued
    /// on an unbounded channel and read back through `receive`.
    struct FakeTransport {
        /// Invoked with `(id, request_type, message)` for every outgoing message
        /// that has a `method`. A `Some` return is sent back as the JSON-RPC
        /// `result` for that id.
        on_request: Arc<
            dyn Fn(u64, Option<RequestType>, serde_json::Value) -> Option<serde_json::Value>
                + Send
                + Sync,
        >,
        /// Sending half of the response channel, written by `send`.
        tx: futures::channel::mpsc::UnboundedSender<String>,
        /// Receiving half of the response channel. It is shared behind an async
        /// mutex so every stream returned by `receive` can lock it.
        rx: Arc<Mutex<futures::channel::mpsc::UnboundedReceiver<String>>>,
        /// Used by `receive` to inject a random delay before each delivery.
        executor: BackgroundExecutor,
    }
1058
1059    impl FakeTransport {
1060        fn new(
1061            executor: BackgroundExecutor,
1062            on_request: impl Fn(
1063                u64,
1064                Option<RequestType>,
1065                serde_json::Value,
1066            ) -> Option<serde_json::Value>
1067            + 'static
1068            + Send
1069            + Sync,
1070        ) -> Self {
1071            let (tx, rx) = futures::channel::mpsc::unbounded();
1072            Self {
1073                on_request: Arc::new(on_request),
1074                tx,
1075                rx: Arc::new(Mutex::new(rx)),
1076                executor,
1077            }
1078        }
1079    }
1080
1081    #[async_trait::async_trait]
1082    impl Transport for FakeTransport {
1083        async fn send(&self, message: String) -> Result<()> {
1084            if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&message) {
1085                let id = msg.get("id").and_then(|id| id.as_u64()).unwrap_or(0);
1086
1087                if let Some(method) = msg.get("method") {
1088                    let request_type = method
1089                        .as_str()
1090                        .and_then(|method| types::RequestType::try_from(method).ok());
1091                    if let Some(payload) = (self.on_request.as_ref())(id, request_type, msg) {
1092                        let response = serde_json::json!({
1093                            "jsonrpc": "2.0",
1094                            "id": id,
1095                            "result": payload
1096                        });
1097
1098                        self.tx
1099                            .unbounded_send(response.to_string())
1100                            .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
1101                    }
1102                }
1103            }
1104            Ok(())
1105        }
1106
1107        fn receive(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
1108            let rx = self.rx.clone();
1109            let executor = self.executor.clone();
1110            Box::pin(futures::stream::unfold(rx, move |rx| {
1111                let executor = executor.clone();
1112                async move {
1113                    let mut rx_guard = rx.lock().await;
1114                    executor.simulate_random_delay().await;
1115                    if let Some(message) = rx_guard.next().await {
1116                        drop(rx_guard);
1117                        Some((message, rx))
1118                    } else {
1119                        None
1120                    }
1121                }
1122            }))
1123        }
1124
1125        fn receive_err(&self) -> Pin<Box<dyn Stream<Item = String> + Send>> {
1126            Box::pin(futures::stream::empty())
1127        }
1128    }
1129}