context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::path::Path;
   5use std::sync::Arc;
   6use std::time::Duration;
   7
   8use anyhow::{Context as _, Result};
   9use collections::{HashMap, HashSet};
  10use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  11use futures::{FutureExt as _, future::Either, future::join_all};
  12use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  13use itertools::Itertools;
  14use registry::ContextServerDescriptorRegistry;
  15use remote::RemoteClient;
  16use rpc::{AnyProtoClient, TypedEnvelope, proto};
  17use settings::{Settings as _, SettingsStore};
  18use util::{ResultExt as _, rel_path::RelPath};
  19
  20use crate::{
  21    DisableAiSettings, Project,
  22    project_settings::{ContextServerSettings, ProjectSettings},
  23    worktree_store::WorktreeStore,
  24};
  25
  26/// Maximum timeout for context server requests
  27/// Prevents extremely large timeout values from tying up resources indefinitely.
  28const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
  29
/// Module initialization entry point; currently it only calls
/// [`extension::init`].
pub fn init(cx: &mut App) {
    extension::init(cx);
}
  33
// Declares the actions of the `context_server` namespace. The `///` line on
// each action is part of the macro input, so it is kept verbatim.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  41
/// Public, cloneable summary of a context server's lifecycle stage.
///
/// Derived from the store's internal per-server state and broadcast in
/// [`ServerStatusChangedEvent`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// A start has been requested and its start-up task has not finished yet.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// The server is in an error state (e.g. start-up failed); carries the
    /// error message.
    Error(Arc<str>),
}
  49
  50impl ContextServerStatus {
  51    fn from_state(state: &ContextServerState) -> Self {
  52        match state {
  53            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  54            ContextServerState::Running { .. } => ContextServerStatus::Running,
  55            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  56            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  57        }
  58    }
  59}
  60
/// Internal lifecycle state of one server tracked by [`ContextServerStore`].
///
/// Every variant keeps the server handle and the configuration it was set up
/// with, so both remain retrievable regardless of the current stage.
enum ContextServerState {
    /// Start-up in progress. `_task` is the task driving the start; dropping
    /// this state drops it (presumably cancelling the pending start).
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped by the store.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server failed; `error` is the failure's message.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
}
  81
  82impl ContextServerState {
  83    pub fn server(&self) -> Arc<ContextServer> {
  84        match self {
  85            ContextServerState::Starting { server, .. } => server.clone(),
  86            ContextServerState::Running { server, .. } => server.clone(),
  87            ContextServerState::Stopped { server, .. } => server.clone(),
  88            ContextServerState::Error { server, .. } => server.clone(),
  89        }
  90    }
  91
  92    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
  93        match self {
  94            ContextServerState::Starting { configuration, .. } => configuration.clone(),
  95            ContextServerState::Running { configuration, .. } => configuration.clone(),
  96            ContextServerState::Stopped { configuration, .. } => configuration.clone(),
  97            ContextServerState::Error { configuration, .. } => configuration.clone(),
  98        }
  99    }
 100}
 101
/// How a context server is reached: a stdio command or an HTTP endpoint.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A stdio command, either from `Stdio` settings or rewritten by the store
    /// after resolving a command on a remote host.
    Custom {
        command: ContextServerCommand,
        /// When `true` in a remote project, the command is resolved on the
        /// remote host (see `ContextServerStore::create_context_server`).
        remote: bool,
    },
    /// A stdio command obtained from an extension's context server descriptor.
    Extension {
        command: ContextServerCommand,
        /// The `settings` value from the extension's settings entry.
        settings: serde_json::Value,
        /// Same meaning as for `Custom`.
        remote: bool,
    },
    /// A server reached over HTTP.
    Http {
        url: url::Url,
        /// Headers handed to `ContextServer::http`.
        headers: HashMap<String, String>,
        /// Per-server request timeout in seconds. When `None`, the global
        /// `context_server_timeout` is used. Either value is clamped to
        /// `MAX_TIMEOUT_SECS`.
        timeout: Option<u64>,
    },
}
 119
 120impl ContextServerConfiguration {
 121    pub fn command(&self) -> Option<&ContextServerCommand> {
 122        match self {
 123            ContextServerConfiguration::Custom { command, .. } => Some(command),
 124            ContextServerConfiguration::Extension { command, .. } => Some(command),
 125            ContextServerConfiguration::Http { .. } => None,
 126        }
 127    }
 128
 129    pub fn remote(&self) -> bool {
 130        match self {
 131            ContextServerConfiguration::Custom { remote, .. } => *remote,
 132            ContextServerConfiguration::Extension { remote, .. } => *remote,
 133            ContextServerConfiguration::Http { .. } => false,
 134        }
 135    }
 136
 137    pub async fn from_settings(
 138        settings: ContextServerSettings,
 139        id: ContextServerId,
 140        registry: Entity<ContextServerDescriptorRegistry>,
 141        worktree_store: Entity<WorktreeStore>,
 142        cx: &AsyncApp,
 143    ) -> Option<Self> {
 144        const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
 145
 146        match settings {
 147            ContextServerSettings::Stdio {
 148                enabled: _,
 149                command,
 150                remote,
 151            } => Some(ContextServerConfiguration::Custom { command, remote }),
 152            ContextServerSettings::Extension {
 153                enabled: _,
 154                settings,
 155                remote,
 156            } => {
 157                let descriptor =
 158                    cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
 159
 160                let command_future = descriptor.command(worktree_store, cx);
 161                let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
 162
 163                match futures::future::select(command_future, timeout_future).await {
 164                    Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
 165                        command,
 166                        settings,
 167                        remote,
 168                    }),
 169                    Either::Left((Err(e), _)) => {
 170                        log::error!(
 171                            "Failed to create context server configuration from settings: {e:#}"
 172                        );
 173                        None
 174                    }
 175                    Either::Right(_) => {
 176                        log::error!(
 177                            "Timed out resolving command for extension context server {id}"
 178                        );
 179                        None
 180                    }
 181                }
 182            }
 183            ContextServerSettings::Http {
 184                enabled: _,
 185                url,
 186                headers: auth,
 187                timeout,
 188            } => {
 189                let url = url::Url::parse(&url).log_err()?;
 190                Some(ContextServerConfiguration::Http {
 191                    url,
 192                    headers: auth,
 193                    timeout,
 194                })
 195            }
 196        }
 197    }
 198}
 199
/// Constructor hook that builds a [`ContextServer`] from an ID and configuration.
///
/// When installed on the store (via the `test-support` constructors), it
/// replaces the built-in HTTP/stdio construction in
/// `ContextServerStore::create_context_server`.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 202
/// Where the store's project lives, and the RPC clients that go with it.
enum ContextServerStoreState {
    /// The project is hosted by this process.
    Local {
        /// `(project_id, client)` recorded by `ContextServerStore::shared`.
        downstream_client: Option<(u64, AnyProtoClient)>,
        /// Set for headless stores; only these accept `GetContextServerCommand`.
        is_headless: bool,
    },
    /// The project is hosted elsewhere. Remote commands are resolved through
    /// `upstream_client`.
    Remote {
        project_id: u64,
        upstream_client: Entity<RemoteClient>,
    },
}
 213
/// Tracks the context servers of a project: their settings, their lifecycle
/// state, and (when the maintenance loop is enabled) keeping them in sync
/// with settings and the extension registry.
pub struct ContextServerStore {
    /// Local vs. remote project, plus the associated RPC clients.
    state: ContextServerStoreState,
    /// Context server settings keyed by server ID. Refreshed whenever the
    /// global `SettingsStore` changes.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Lifecycle state of each server the store currently knows about.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Cached, sorted ID list returned by `server_ids`. Rebuilt by
    /// `populate_server_ids`.
    server_ids: Vec<ContextServerId>,
    /// Source of worktrees for settings lookup and the fallback working directory.
    worktree_store: Entity<WorktreeStore>,
    /// Used to pick the active project directory as a stdio server's working
    /// directory.
    project: Option<WeakEntity<Project>>,
    /// Registry of extension-provided context server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// The in-flight maintenance pass, if any.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for server construction.
    context_server_factory: Option<ContextServerFactory>,
    /// Set when a maintenance pass is requested while one is already running,
    /// so another pass is scheduled when it finishes.
    needs_server_update: bool,
    /// Observers held for the store's lifetime (presumably dropping them
    /// unregisters the callbacks).
    _subscriptions: Vec<Subscription>,
}
 227
/// Emitted by [`ContextServerStore`] whenever a server's state is recorded or
/// the server is removed.
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The server's new status.
    pub status: ContextServerStatus,
}
 232
// Allows the store to `cx.emit` `ServerStatusChangedEvent`s to subscribers.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
 234
 235impl ContextServerStore {
 236    pub fn local(
 237        worktree_store: Entity<WorktreeStore>,
 238        weak_project: Option<WeakEntity<Project>>,
 239        headless: bool,
 240        cx: &mut Context<Self>,
 241    ) -> Self {
 242        Self::new_internal(
 243            !headless,
 244            None,
 245            ContextServerDescriptorRegistry::default_global(cx),
 246            worktree_store,
 247            weak_project,
 248            ContextServerStoreState::Local {
 249                downstream_client: None,
 250                is_headless: headless,
 251            },
 252            cx,
 253        )
 254    }
 255
 256    pub fn remote(
 257        project_id: u64,
 258        upstream_client: Entity<RemoteClient>,
 259        worktree_store: Entity<WorktreeStore>,
 260        weak_project: Option<WeakEntity<Project>>,
 261        cx: &mut Context<Self>,
 262    ) -> Self {
 263        Self::new_internal(
 264            true,
 265            None,
 266            ContextServerDescriptorRegistry::default_global(cx),
 267            worktree_store,
 268            weak_project,
 269            ContextServerStoreState::Remote {
 270                project_id,
 271                upstream_client,
 272            },
 273            cx,
 274        )
 275    }
 276
 277    pub fn init_headless(session: &AnyProtoClient) {
 278        session.add_entity_request_handler(Self::handle_get_context_server_command);
 279    }
 280
 281    pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
 282        if let ContextServerStoreState::Local {
 283            downstream_client, ..
 284        } = &mut self.state
 285        {
 286            *downstream_client = Some((project_id, client));
 287        }
 288    }
 289
 290    pub fn is_remote_project(&self) -> bool {
 291        matches!(self.state, ContextServerStoreState::Remote { .. })
 292    }
 293
 294    /// Returns all configured context server ids, excluding the ones that are disabled
 295    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 296        self.context_server_settings
 297            .iter()
 298            .filter(|(_, settings)| settings.enabled())
 299            .map(|(id, _)| ContextServerId(id.clone()))
 300            .collect()
 301    }
 302
 303    #[cfg(feature = "test-support")]
 304    pub fn test(
 305        registry: Entity<ContextServerDescriptorRegistry>,
 306        worktree_store: Entity<WorktreeStore>,
 307        weak_project: Option<WeakEntity<Project>>,
 308        cx: &mut Context<Self>,
 309    ) -> Self {
 310        Self::new_internal(
 311            false,
 312            None,
 313            registry,
 314            worktree_store,
 315            weak_project,
 316            ContextServerStoreState::Local {
 317                downstream_client: None,
 318                is_headless: false,
 319            },
 320            cx,
 321        )
 322    }
 323
 324    #[cfg(feature = "test-support")]
 325    pub fn test_maintain_server_loop(
 326        context_server_factory: Option<ContextServerFactory>,
 327        registry: Entity<ContextServerDescriptorRegistry>,
 328        worktree_store: Entity<WorktreeStore>,
 329        weak_project: Option<WeakEntity<Project>>,
 330        cx: &mut Context<Self>,
 331    ) -> Self {
 332        Self::new_internal(
 333            true,
 334            context_server_factory,
 335            registry,
 336            worktree_store,
 337            weak_project,
 338            ContextServerStoreState::Local {
 339                downstream_client: None,
 340                is_headless: false,
 341            },
 342            cx,
 343        )
 344    }
 345
 346    #[cfg(feature = "test-support")]
 347    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
 348        self.context_server_factory = Some(factory);
 349    }
 350
 351    #[cfg(feature = "test-support")]
 352    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
 353        &self.registry
 354    }
 355
 356    #[cfg(feature = "test-support")]
 357    pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 358        let configuration = Arc::new(ContextServerConfiguration::Custom {
 359            command: ContextServerCommand {
 360                path: "test".into(),
 361                args: vec![],
 362                env: None,
 363                timeout: None,
 364            },
 365            remote: false,
 366        });
 367        self.run_server(server, configuration, cx);
 368    }
 369
 370    fn new_internal(
 371        maintain_server_loop: bool,
 372        context_server_factory: Option<ContextServerFactory>,
 373        registry: Entity<ContextServerDescriptorRegistry>,
 374        worktree_store: Entity<WorktreeStore>,
 375        weak_project: Option<WeakEntity<Project>>,
 376        state: ContextServerStoreState,
 377        cx: &mut Context<Self>,
 378    ) -> Self {
 379        let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 380            let settings =
 381                &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
 382            if &this.context_server_settings == settings {
 383                return;
 384            }
 385            this.context_server_settings = settings.clone();
 386            if maintain_server_loop {
 387                this.available_context_servers_changed(cx);
 388            }
 389        })];
 390
 391        if maintain_server_loop {
 392            subscriptions.push(cx.observe(&registry, |this, _registry, cx| {
 393                this.available_context_servers_changed(cx);
 394            }));
 395        }
 396
 397        let mut this = Self {
 398            state,
 399            _subscriptions: subscriptions,
 400            context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
 401                .context_servers
 402                .clone(),
 403            worktree_store,
 404            project: weak_project,
 405            registry,
 406            needs_server_update: false,
 407            servers: HashMap::default(),
 408            server_ids: Default::default(),
 409            update_servers_task: None,
 410            context_server_factory,
 411        };
 412        if maintain_server_loop {
 413            this.available_context_servers_changed(cx);
 414        }
 415        this
 416    }
 417
 418    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 419        self.servers.get(id).map(|state| state.server())
 420    }
 421
 422    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 423        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 424            Some(server.clone())
 425        } else {
 426            None
 427        }
 428    }
 429
 430    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 431        self.servers.get(id).map(ContextServerStatus::from_state)
 432    }
 433
 434    pub fn configuration_for_server(
 435        &self,
 436        id: &ContextServerId,
 437    ) -> Option<Arc<ContextServerConfiguration>> {
 438        self.servers.get(id).map(|state| state.configuration())
 439    }
 440
 441    /// Returns a sorted slice of available unique context server IDs. Within the
 442    /// slice, context servers which have `mcp-server-` as a prefix in their ID will
 443    /// appear after servers that do not have this prefix in their ID.
 444    pub fn server_ids(&self) -> &[ContextServerId] {
 445        self.server_ids.as_slice()
 446    }
 447
 448    fn populate_server_ids(&mut self, cx: &App) {
 449        self.server_ids = self
 450            .servers
 451            .keys()
 452            .cloned()
 453            .chain(
 454                self.registry
 455                    .read(cx)
 456                    .context_server_descriptors()
 457                    .into_iter()
 458                    .map(|(id, _)| ContextServerId(id)),
 459            )
 460            .chain(
 461                self.context_server_settings
 462                    .keys()
 463                    .map(|id| ContextServerId(id.clone())),
 464            )
 465            .unique()
 466            .sorted_unstable_by(
 467                // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
 468                |a, b| {
 469                    const MCP_PREFIX: &str = "mcp-server-";
 470                    match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
 471                        // If one has mcp-server- prefix and other doesn't, non-mcp comes first
 472                        (Some(_), None) => std::cmp::Ordering::Greater,
 473                        (None, Some(_)) => std::cmp::Ordering::Less,
 474                        // If both have same prefix status, sort by appropriate key
 475                        (Some(a), Some(b)) => a.cmp(b),
 476                        (None, None) => a.0.cmp(&b.0),
 477                    }
 478                },
 479            )
 480            .collect();
 481    }
 482
 483    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 484        self.servers
 485            .values()
 486            .filter_map(|state| {
 487                if let ContextServerState::Running { server, .. } = state {
 488                    Some(server.clone())
 489                } else {
 490                    None
 491                }
 492            })
 493            .collect()
 494    }
 495
 496    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 497        cx.spawn(async move |this, cx| {
 498            let this = this.upgrade().context("Context server store dropped")?;
 499            let settings = this
 500                .update(cx, |this, _| {
 501                    this.context_server_settings.get(&server.id().0).cloned()
 502                })
 503                .context("Failed to get context server settings")?;
 504
 505            if !settings.enabled() {
 506                return anyhow::Ok(());
 507            }
 508
 509            let (registry, worktree_store) = this.update(cx, |this, _| {
 510                (this.registry.clone(), this.worktree_store.clone())
 511            });
 512            let configuration = ContextServerConfiguration::from_settings(
 513                settings,
 514                server.id(),
 515                registry,
 516                worktree_store,
 517                cx,
 518            )
 519            .await
 520            .context("Failed to create context server configuration")?;
 521
 522            this.update(cx, |this, cx| {
 523                this.run_server(server, Arc::new(configuration), cx)
 524            });
 525            Ok(())
 526        })
 527        .detach_and_log_err(cx);
 528    }
 529
 530    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 531        if matches!(
 532            self.servers.get(id),
 533            Some(ContextServerState::Stopped { .. })
 534        ) {
 535            return Ok(());
 536        }
 537
 538        let state = self
 539            .servers
 540            .remove(id)
 541            .context("Context server not found")?;
 542
 543        let server = state.server();
 544        let configuration = state.configuration();
 545        let mut result = Ok(());
 546        if let ContextServerState::Running { server, .. } = &state {
 547            result = server.stop();
 548        }
 549        drop(state);
 550
 551        self.update_server_state(
 552            id.clone(),
 553            ContextServerState::Stopped {
 554                configuration,
 555                server,
 556            },
 557            cx,
 558        );
 559
 560        result
 561    }
 562
 563    pub fn stop_all_servers(&mut self, cx: &mut Context<Self>) {
 564        let server_ids: Vec<_> = self.servers.keys().cloned().collect();
 565        for id in server_ids {
 566            self.stop_server(&id, cx).log_err();
 567        }
 568    }
 569
 570    fn run_server(
 571        &mut self,
 572        server: Arc<ContextServer>,
 573        configuration: Arc<ContextServerConfiguration>,
 574        cx: &mut Context<Self>,
 575    ) {
 576        let id = server.id();
 577        if matches!(
 578            self.servers.get(&id),
 579            Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
 580        ) {
 581            self.stop_server(&id, cx).log_err();
 582        }
 583        let task = cx.spawn({
 584            let id = server.id();
 585            let server = server.clone();
 586            let configuration = configuration.clone();
 587
 588            async move |this, cx| {
 589                match server.clone().start(cx).await {
 590                    Ok(_) => {
 591                        debug_assert!(server.client().is_some());
 592
 593                        this.update(cx, |this, cx| {
 594                            this.update_server_state(
 595                                id.clone(),
 596                                ContextServerState::Running {
 597                                    server,
 598                                    configuration,
 599                                },
 600                                cx,
 601                            )
 602                        })
 603                        .log_err()
 604                    }
 605                    Err(err) => {
 606                        log::error!("{} context server failed to start: {}", id, err);
 607                        this.update(cx, |this, cx| {
 608                            this.update_server_state(
 609                                id.clone(),
 610                                ContextServerState::Error {
 611                                    configuration,
 612                                    server,
 613                                    error: err.to_string().into(),
 614                                },
 615                                cx,
 616                            )
 617                        })
 618                        .log_err()
 619                    }
 620                };
 621            }
 622        });
 623
 624        self.update_server_state(
 625            id.clone(),
 626            ContextServerState::Starting {
 627                configuration,
 628                _task: task,
 629                server,
 630            },
 631            cx,
 632        );
 633    }
 634
 635    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 636        let state = self
 637            .servers
 638            .remove(id)
 639            .context("Context server not found")?;
 640        drop(state);
 641        cx.emit(ServerStatusChangedEvent {
 642            server_id: id.clone(),
 643            status: ContextServerStatus::Stopped,
 644        });
 645        Ok(())
 646    }
 647
 648    pub async fn create_context_server(
 649        this: WeakEntity<Self>,
 650        id: ContextServerId,
 651        configuration: Arc<ContextServerConfiguration>,
 652        cx: &mut AsyncApp,
 653    ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
 654        let remote = configuration.remote();
 655        let needs_remote_command = match configuration.as_ref() {
 656            ContextServerConfiguration::Custom { .. }
 657            | ContextServerConfiguration::Extension { .. } => remote,
 658            ContextServerConfiguration::Http { .. } => false,
 659        };
 660
 661        let (remote_state, is_remote_project) = this.update(cx, |this, _| {
 662            let remote_state = match &this.state {
 663                ContextServerStoreState::Remote {
 664                    project_id,
 665                    upstream_client,
 666                } if needs_remote_command => Some((*project_id, upstream_client.clone())),
 667                _ => None,
 668            };
 669            (remote_state, this.is_remote_project())
 670        })?;
 671
 672        let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
 673            this.project
 674                .as_ref()
 675                .and_then(|project| {
 676                    project
 677                        .read_with(cx, |project, cx| project.active_project_directory(cx))
 678                        .ok()
 679                        .flatten()
 680                })
 681                .or_else(|| {
 682                    this.worktree_store.read_with(cx, |store, cx| {
 683                        store.visible_worktrees(cx).fold(None, |acc, item| {
 684                            if acc.is_none() {
 685                                item.read(cx).root_dir()
 686                            } else {
 687                                acc
 688                            }
 689                        })
 690                    })
 691                })
 692        })?;
 693
 694        let configuration = if let Some((project_id, upstream_client)) = remote_state {
 695            let root_dir = root_path.as_ref().map(|p| p.display().to_string());
 696
 697            let response = upstream_client
 698                .update(cx, |client, _| {
 699                    client
 700                        .proto_client()
 701                        .request(proto::GetContextServerCommand {
 702                            project_id,
 703                            server_id: id.0.to_string(),
 704                            root_dir: root_dir.clone(),
 705                        })
 706                })
 707                .await?;
 708
 709            let remote_command = upstream_client.update(cx, |client, _| {
 710                client.build_command(
 711                    Some(response.path),
 712                    &response.args,
 713                    &response.env.into_iter().collect(),
 714                    root_dir,
 715                    None,
 716                )
 717            })?;
 718
 719            let command = ContextServerCommand {
 720                path: remote_command.program.into(),
 721                args: remote_command.args,
 722                env: Some(remote_command.env.into_iter().collect()),
 723                timeout: None,
 724            };
 725
 726            Arc::new(ContextServerConfiguration::Custom { command, remote })
 727        } else {
 728            configuration
 729        };
 730
 731        let server: Arc<ContextServer> = this.update(cx, |this, cx| {
 732            let global_timeout =
 733                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
 734
 735            if let Some(factory) = this.context_server_factory.as_ref() {
 736                return anyhow::Ok(factory(id.clone(), configuration.clone()));
 737            }
 738
 739            match configuration.as_ref() {
 740                ContextServerConfiguration::Http {
 741                    url,
 742                    headers,
 743                    timeout,
 744                } => anyhow::Ok(Arc::new(ContextServer::http(
 745                    id,
 746                    url,
 747                    headers.clone(),
 748                    cx.http_client(),
 749                    cx.background_executor().clone(),
 750                    Some(Duration::from_secs(
 751                        timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
 752                    )),
 753                )?)),
 754                _ => {
 755                    let mut command = configuration
 756                        .command()
 757                        .context("Missing command configuration for stdio context server")?
 758                        .clone();
 759                    command.timeout = Some(
 760                        command
 761                            .timeout
 762                            .unwrap_or(global_timeout)
 763                            .min(MAX_TIMEOUT_SECS),
 764                    );
 765
 766                    // Don't pass remote paths as working directory for locally-spawned processes
 767                    let working_directory = if is_remote_project { None } else { root_path };
 768                    anyhow::Ok(Arc::new(ContextServer::stdio(
 769                        id,
 770                        command,
 771                        working_directory,
 772                    )))
 773                }
 774            }
 775        })??;
 776
 777        Ok((server, configuration))
 778    }
 779
 780    async fn handle_get_context_server_command(
 781        this: Entity<Self>,
 782        envelope: TypedEnvelope<proto::GetContextServerCommand>,
 783        mut cx: AsyncApp,
 784    ) -> Result<proto::ContextServerCommand> {
 785        let server_id = ContextServerId(envelope.payload.server_id.into());
 786
 787        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
 788            let ContextServerStoreState::Local {
 789                is_headless: true, ..
 790            } = &this.state
 791            else {
 792                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
 793            };
 794
 795            let settings = this
 796                .context_server_settings
 797                .get(&server_id.0)
 798                .cloned()
 799                .or_else(|| {
 800                    this.registry
 801                        .read(inner_cx)
 802                        .context_server_descriptor(&server_id.0)
 803                        .map(|_| ContextServerSettings::default_extension())
 804                })
 805                .with_context(|| format!("context server `{}` not found", server_id))?;
 806
 807            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
 808        })?;
 809
 810        let configuration = ContextServerConfiguration::from_settings(
 811            settings,
 812            server_id.clone(),
 813            registry,
 814            worktree_store,
 815            &cx,
 816        )
 817        .await
 818        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
 819
 820        let command = configuration
 821            .command()
 822            .context("context server has no command (HTTP servers don't need RPC)")?;
 823
 824        Ok(proto::ContextServerCommand {
 825            path: command.path.display().to_string(),
 826            args: command.args.clone(),
 827            env: command
 828                .env
 829                .clone()
 830                .map(|env| env.into_iter().collect())
 831                .unwrap_or_default(),
 832        })
 833    }
 834
 835    fn resolve_project_settings<'a>(
 836        worktree_store: &'a Entity<WorktreeStore>,
 837        cx: &'a App,
 838    ) -> &'a ProjectSettings {
 839        let location = worktree_store
 840            .read(cx)
 841            .visible_worktrees(cx)
 842            .next()
 843            .map(|worktree| settings::SettingsLocation {
 844                worktree_id: worktree.read(cx).id(),
 845                path: RelPath::empty(),
 846            });
 847        ProjectSettings::get(location, cx)
 848    }
 849
 850    fn update_server_state(
 851        &mut self,
 852        id: ContextServerId,
 853        state: ContextServerState,
 854        cx: &mut Context<Self>,
 855    ) {
 856        let status = ContextServerStatus::from_state(&state);
 857        self.servers.insert(id.clone(), state);
 858        cx.emit(ServerStatusChangedEvent {
 859            server_id: id,
 860            status,
 861        });
 862    }
 863
 864    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
 865        if self.update_servers_task.is_some() {
 866            self.needs_server_update = true;
 867        } else {
 868            self.needs_server_update = false;
 869            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
 870                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
 871                    log::error!("Error maintaining context servers: {}", err);
 872                }
 873
 874                this.update(cx, |this, cx| {
 875                    this.populate_server_ids(cx);
 876                    cx.notify();
 877                    this.update_servers_task.take();
 878                    if this.needs_server_update {
 879                        this.available_context_servers_changed(cx);
 880                    }
 881                })?;
 882
 883                Ok(())
 884            }));
 885        }
 886    }
 887
 888    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 889        // Don't start context servers if AI is disabled
 890        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
 891        if ai_disabled {
 892            // Stop all running servers when AI is disabled
 893            this.update(cx, |this, cx| {
 894                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
 895                for id in server_ids {
 896                    let _ = this.stop_server(&id, cx);
 897                }
 898            })?;
 899            return Ok(());
 900        }
 901
 902        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
 903            (
 904                this.context_server_settings.clone(),
 905                this.registry.clone(),
 906                this.worktree_store.clone(),
 907            )
 908        })?;
 909
 910        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
 911            configured_servers
 912                .entry(id)
 913                .or_insert(ContextServerSettings::default_extension());
 914        }
 915
 916        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
 917            configured_servers
 918                .into_iter()
 919                .partition(|(_, settings)| settings.enabled());
 920
 921        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
 922            let id = ContextServerId(id);
 923            ContextServerConfiguration::from_settings(
 924                settings,
 925                id.clone(),
 926                registry.clone(),
 927                worktree_store.clone(),
 928                cx,
 929            )
 930            .map(move |config| (id, config))
 931        }))
 932        .await
 933        .into_iter()
 934        .filter_map(|(id, config)| config.map(|config| (id, config)))
 935        .collect::<HashMap<_, _>>();
 936
 937        let mut servers_to_start = Vec::new();
 938        let mut servers_to_remove = HashSet::default();
 939        let mut servers_to_stop = HashSet::default();
 940
 941        this.update(cx, |this, _cx| {
 942            for server_id in this.servers.keys() {
 943                // All servers that are not in desired_servers should be removed from the store.
 944                // This can happen if the user removed a server from the context server settings.
 945                if !configured_servers.contains_key(server_id) {
 946                    if disabled_servers.contains_key(&server_id.0) {
 947                        servers_to_stop.insert(server_id.clone());
 948                    } else {
 949                        servers_to_remove.insert(server_id.clone());
 950                    }
 951                }
 952            }
 953
 954            for (id, config) in configured_servers {
 955                let state = this.servers.get(&id);
 956                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
 957                let existing_config = state.as_ref().map(|state| state.configuration());
 958                if existing_config.as_deref() != Some(&config) || is_stopped {
 959                    let config = Arc::new(config);
 960                    servers_to_start.push((id.clone(), config));
 961                    if this.servers.contains_key(&id) {
 962                        servers_to_stop.insert(id);
 963                    }
 964                }
 965            }
 966
 967            anyhow::Ok(())
 968        })??;
 969
 970        this.update(cx, |this, inner_cx| {
 971            for id in servers_to_stop {
 972                this.stop_server(&id, inner_cx)?;
 973            }
 974            for id in servers_to_remove {
 975                this.remove_server(&id, inner_cx)?;
 976            }
 977            anyhow::Ok(())
 978        })??;
 979
 980        for (id, config) in servers_to_start {
 981            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
 982                Ok((server, config)) => {
 983                    this.update(cx, |this, cx| {
 984                        this.run_server(server, config, cx);
 985                    })?;
 986                }
 987                Err(err) => {
 988                    log::error!("{id} context server failed to create: {err:#}");
 989                    this.update(cx, |_this, cx| {
 990                        cx.emit(ServerStatusChangedEvent {
 991                            server_id: id,
 992                            status: ContextServerStatus::Error(err.to_string().into()),
 993                        });
 994                        cx.notify();
 995                    })?;
 996                }
 997            }
 998        }
 999
1000        Ok(())
1001    }
1002}