context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::path::Path;
   5use std::sync::Arc;
   6use std::time::Duration;
   7
   8use anyhow::{Context as _, Result};
   9use collections::{HashMap, HashSet};
  10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
  11use context_server::transport::{HttpTransport, TransportError};
  12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  13use credentials_provider::CredentialsProvider;
  14use futures::future::Either;
  15use futures::{FutureExt as _, StreamExt as _, future::join_all};
  16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  17use http_client::HttpClient;
  18use itertools::Itertools;
  19use rand::Rng as _;
  20use registry::ContextServerDescriptorRegistry;
  21use remote::RemoteClient;
  22use rpc::{AnyProtoClient, TypedEnvelope, proto};
  23use settings::{Settings as _, SettingsStore};
  24use util::{ResultExt as _, rel_path::RelPath};
  25
  26use crate::{
  27    DisableAiSettings, Project,
  28    project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings},
  29    worktree_store::WorktreeStore,
  30};
  31
  32/// Maximum timeout for context server requests
  33/// Prevents extremely large timeout values from tying up resources indefinitely.
  34const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
  35
  36pub fn init(cx: &mut App) {
  37    extension::init(cx);
  38}
  39
  40actions!(
  41    context_server,
  42    [
  43        /// Restarts the context server.
  44        Restart
  45    ]
  46);
  47
  48#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  49pub enum ContextServerStatus {
  50    Starting,
  51    Running,
  52    Stopped,
  53    Error(Arc<str>),
  54    /// The server returned 401 and OAuth authorization is needed. The UI
  55    /// should show an "Authenticate" button.
  56    AuthRequired,
  57    /// The server has a pre-registered OAuth client_id, but a client_secret
  58    /// is needed and not available in settings or the keychain. The UI should
  59    /// show a text input to collect it.
  60    ClientSecretRequired,
  61    /// The OAuth browser flow is in progress — the user has been redirected
  62    /// to the authorization server and we're waiting for the callback.
  63    Authenticating,
  64}
  65
  66impl ContextServerStatus {
  67    fn from_state(state: &ContextServerState) -> Self {
  68        match state {
  69            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  70            ContextServerState::Running { .. } => ContextServerStatus::Running,
  71            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  72            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  73            ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
  74            ContextServerState::ClientSecretRequired { .. } => {
  75                ContextServerStatus::ClientSecretRequired
  76            }
  77            ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
  78        }
  79    }
  80}
  81
  82enum ContextServerState {
  83    Starting {
  84        server: Arc<ContextServer>,
  85        configuration: Arc<ContextServerConfiguration>,
  86        _task: Task<()>,
  87    },
  88    Running {
  89        server: Arc<ContextServer>,
  90        configuration: Arc<ContextServerConfiguration>,
  91    },
  92    Stopped {
  93        server: Arc<ContextServer>,
  94        configuration: Arc<ContextServerConfiguration>,
  95    },
  96    Error {
  97        server: Arc<ContextServer>,
  98        configuration: Arc<ContextServerConfiguration>,
  99        error: Arc<str>,
 100    },
 101    /// The server requires OAuth authorization before it can be used. The
 102    /// `OAuthDiscovery` holds everything needed to start the browser flow.
 103    AuthRequired {
 104        server: Arc<ContextServer>,
 105        configuration: Arc<ContextServerConfiguration>,
 106        discovery: Arc<OAuthDiscovery>,
 107    },
 108    /// A pre-registered client_id is configured but no client_secret was found
 109    /// in settings or the keychain. The user needs to provide it interactively.
 110    ClientSecretRequired {
 111        server: Arc<ContextServer>,
 112        configuration: Arc<ContextServerConfiguration>,
 113        discovery: Arc<OAuthDiscovery>,
 114    },
 115    /// The OAuth browser flow is in progress. The user has been redirected
 116    /// to the authorization server and we're waiting for the callback.
 117    Authenticating {
 118        server: Arc<ContextServer>,
 119        configuration: Arc<ContextServerConfiguration>,
 120        _task: Task<()>,
 121    },
 122}
 123
 124impl ContextServerState {
 125    pub fn server(&self) -> Arc<ContextServer> {
 126        match self {
 127            ContextServerState::Starting { server, .. }
 128            | ContextServerState::Running { server, .. }
 129            | ContextServerState::Stopped { server, .. }
 130            | ContextServerState::Error { server, .. }
 131            | ContextServerState::AuthRequired { server, .. }
 132            | ContextServerState::ClientSecretRequired { server, .. }
 133            | ContextServerState::Authenticating { server, .. } => server.clone(),
 134        }
 135    }
 136
 137    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
 138        match self {
 139            ContextServerState::Starting { configuration, .. }
 140            | ContextServerState::Running { configuration, .. }
 141            | ContextServerState::Stopped { configuration, .. }
 142            | ContextServerState::Error { configuration, .. }
 143            | ContextServerState::AuthRequired { configuration, .. }
 144            | ContextServerState::ClientSecretRequired { configuration, .. }
 145            | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
 146        }
 147    }
 148}
 149
 150#[derive(Debug, PartialEq, Eq)]
 151pub enum ContextServerConfiguration {
 152    Custom {
 153        command: ContextServerCommand,
 154        remote: bool,
 155    },
 156    Extension {
 157        command: ContextServerCommand,
 158        settings: serde_json::Value,
 159        remote: bool,
 160    },
 161    Http {
 162        url: url::Url,
 163        headers: HashMap<String, String>,
 164        timeout: Option<u64>,
 165        oauth: Option<OAuthClientSettings>,
 166    },
 167}
 168
 169impl ContextServerConfiguration {
 170    pub fn command(&self) -> Option<&ContextServerCommand> {
 171        match self {
 172            ContextServerConfiguration::Custom { command, .. } => Some(command),
 173            ContextServerConfiguration::Extension { command, .. } => Some(command),
 174            ContextServerConfiguration::Http { .. } => None,
 175        }
 176    }
 177
 178    pub fn has_static_auth_header(&self) -> bool {
 179        match self {
 180            ContextServerConfiguration::Http { headers, .. } => headers
 181                .keys()
 182                .any(|k| k.eq_ignore_ascii_case("authorization")),
 183            _ => false,
 184        }
 185    }
 186
 187    pub fn remote(&self) -> bool {
 188        match self {
 189            ContextServerConfiguration::Custom { remote, .. } => *remote,
 190            ContextServerConfiguration::Extension { remote, .. } => *remote,
 191            ContextServerConfiguration::Http { .. } => false,
 192        }
 193    }
 194
 195    pub async fn from_settings(
 196        settings: ContextServerSettings,
 197        id: ContextServerId,
 198        registry: Entity<ContextServerDescriptorRegistry>,
 199        worktree_store: Entity<WorktreeStore>,
 200        cx: &AsyncApp,
 201    ) -> Option<Self> {
 202        const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
 203
 204        match settings {
 205            ContextServerSettings::Stdio {
 206                enabled: _,
 207                command,
 208                remote,
 209            } => Some(ContextServerConfiguration::Custom { command, remote }),
 210            ContextServerSettings::Extension {
 211                enabled: _,
 212                settings,
 213                remote,
 214            } => {
 215                let descriptor =
 216                    cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
 217
 218                let command_future = descriptor.command(worktree_store, cx);
 219                let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
 220
 221                match futures::future::select(command_future, timeout_future).await {
 222                    Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
 223                        command,
 224                        settings,
 225                        remote,
 226                    }),
 227                    Either::Left((Err(e), _)) => {
 228                        log::error!(
 229                            "Failed to create context server configuration from settings: {e:#}"
 230                        );
 231                        None
 232                    }
 233                    Either::Right(_) => {
 234                        log::error!(
 235                            "Timed out resolving command for extension context server {id}"
 236                        );
 237                        None
 238                    }
 239                }
 240            }
 241            ContextServerSettings::Http {
 242                enabled: _,
 243                url,
 244                headers: auth,
 245                timeout,
 246                oauth,
 247            } => {
 248                let url = url::Url::parse(&url).log_err()?;
 249                Some(ContextServerConfiguration::Http {
 250                    url,
 251                    headers: auth,
 252                    timeout,
 253                    oauth,
 254                })
 255            }
 256        }
 257    }
 258}
 259
 260pub type ContextServerFactory =
 261    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 262
 263enum ContextServerStoreState {
 264    Local {
 265        downstream_client: Option<(u64, AnyProtoClient)>,
 266        is_headless: bool,
 267    },
 268    Remote {
 269        project_id: u64,
 270        upstream_client: Entity<RemoteClient>,
 271    },
 272}
 273
 274pub struct ContextServerStore {
 275    state: ContextServerStoreState,
 276    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
 277    servers: HashMap<ContextServerId, ContextServerState>,
 278    server_ids: Vec<ContextServerId>,
 279    worktree_store: Entity<WorktreeStore>,
 280    project: Option<WeakEntity<Project>>,
 281    registry: Entity<ContextServerDescriptorRegistry>,
 282    update_servers_task: Option<Task<Result<()>>>,
 283    context_server_factory: Option<ContextServerFactory>,
 284    needs_server_update: bool,
 285    ai_disabled: bool,
 286    _subscriptions: Vec<Subscription>,
 287}
 288
 289pub struct ServerStatusChangedEvent {
 290    pub server_id: ContextServerId,
 291    pub status: ContextServerStatus,
 292}
 293
 294impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
 295
 296impl ContextServerStore {
 297    pub fn local(
 298        worktree_store: Entity<WorktreeStore>,
 299        weak_project: Option<WeakEntity<Project>>,
 300        headless: bool,
 301        cx: &mut Context<Self>,
 302    ) -> Self {
 303        Self::new_internal(
 304            !headless,
 305            None,
 306            ContextServerDescriptorRegistry::default_global(cx),
 307            worktree_store,
 308            weak_project,
 309            ContextServerStoreState::Local {
 310                downstream_client: None,
 311                is_headless: headless,
 312            },
 313            cx,
 314        )
 315    }
 316
 317    pub fn remote(
 318        project_id: u64,
 319        upstream_client: Entity<RemoteClient>,
 320        worktree_store: Entity<WorktreeStore>,
 321        weak_project: Option<WeakEntity<Project>>,
 322        cx: &mut Context<Self>,
 323    ) -> Self {
 324        Self::new_internal(
 325            true,
 326            None,
 327            ContextServerDescriptorRegistry::default_global(cx),
 328            worktree_store,
 329            weak_project,
 330            ContextServerStoreState::Remote {
 331                project_id,
 332                upstream_client,
 333            },
 334            cx,
 335        )
 336    }
 337
 338    pub fn init_headless(session: &AnyProtoClient) {
 339        session.add_entity_request_handler(Self::handle_get_context_server_command);
 340    }
 341
 342    pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
 343        if let ContextServerStoreState::Local {
 344            downstream_client, ..
 345        } = &mut self.state
 346        {
 347            *downstream_client = Some((project_id, client));
 348        }
 349    }
 350
 351    pub fn is_remote_project(&self) -> bool {
 352        matches!(self.state, ContextServerStoreState::Remote { .. })
 353    }
 354
 355    /// Returns all configured context server ids, excluding the ones that are disabled
 356    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 357        self.context_server_settings
 358            .iter()
 359            .filter(|(_, settings)| settings.enabled())
 360            .map(|(id, _)| ContextServerId(id.clone()))
 361            .collect()
 362    }
 363
 364    #[cfg(feature = "test-support")]
 365    pub fn test(
 366        registry: Entity<ContextServerDescriptorRegistry>,
 367        worktree_store: Entity<WorktreeStore>,
 368        weak_project: Option<WeakEntity<Project>>,
 369        cx: &mut Context<Self>,
 370    ) -> Self {
 371        Self::new_internal(
 372            false,
 373            None,
 374            registry,
 375            worktree_store,
 376            weak_project,
 377            ContextServerStoreState::Local {
 378                downstream_client: None,
 379                is_headless: false,
 380            },
 381            cx,
 382        )
 383    }
 384
 385    #[cfg(feature = "test-support")]
 386    pub fn test_maintain_server_loop(
 387        context_server_factory: Option<ContextServerFactory>,
 388        registry: Entity<ContextServerDescriptorRegistry>,
 389        worktree_store: Entity<WorktreeStore>,
 390        weak_project: Option<WeakEntity<Project>>,
 391        cx: &mut Context<Self>,
 392    ) -> Self {
 393        Self::new_internal(
 394            true,
 395            context_server_factory,
 396            registry,
 397            worktree_store,
 398            weak_project,
 399            ContextServerStoreState::Local {
 400                downstream_client: None,
 401                is_headless: false,
 402            },
 403            cx,
 404        )
 405    }
 406
 407    #[cfg(feature = "test-support")]
 408    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
 409        self.context_server_factory = Some(factory);
 410    }
 411
 412    #[cfg(feature = "test-support")]
 413    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
 414        &self.registry
 415    }
 416
 417    #[cfg(feature = "test-support")]
 418    pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 419        let configuration = Arc::new(ContextServerConfiguration::Custom {
 420            command: ContextServerCommand {
 421                path: "test".into(),
 422                args: vec![],
 423                env: None,
 424                timeout: None,
 425            },
 426            remote: false,
 427        });
 428        self.run_server(server, configuration, cx);
 429    }
 430
 431    fn new_internal(
 432        maintain_server_loop: bool,
 433        context_server_factory: Option<ContextServerFactory>,
 434        registry: Entity<ContextServerDescriptorRegistry>,
 435        worktree_store: Entity<WorktreeStore>,
 436        weak_project: Option<WeakEntity<Project>>,
 437        state: ContextServerStoreState,
 438        cx: &mut Context<Self>,
 439    ) -> Self {
 440        let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 441            let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 442            let ai_was_disabled = this.ai_disabled;
 443            this.ai_disabled = ai_disabled;
 444
 445            let settings =
 446                &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
 447            let settings_changed = &this.context_server_settings != settings;
 448
 449            if settings_changed {
 450                this.context_server_settings = settings.clone();
 451            }
 452
 453            // When AI is disabled, stop all running servers
 454            if ai_disabled {
 455                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
 456                for id in server_ids {
 457                    this.stop_server(&id, cx).log_err();
 458                }
 459                return;
 460            }
 461
 462            // Trigger updates if AI was re-enabled or settings changed
 463            if maintain_server_loop && (ai_was_disabled || settings_changed) {
 464                this.available_context_servers_changed(cx);
 465            }
 466        })];
 467
 468        if maintain_server_loop {
 469            subscriptions.push(cx.observe(&registry, |this, _registry, cx| {
 470                if !DisableAiSettings::get_global(cx).disable_ai {
 471                    this.available_context_servers_changed(cx);
 472                }
 473            }));
 474        }
 475
 476        let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 477        let mut this = Self {
 478            state,
 479            _subscriptions: subscriptions,
 480            context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
 481                .context_servers
 482                .clone(),
 483            worktree_store,
 484            project: weak_project,
 485            registry,
 486            needs_server_update: false,
 487            ai_disabled,
 488            servers: HashMap::default(),
 489            server_ids: Default::default(),
 490            update_servers_task: None,
 491            context_server_factory,
 492        };
 493        if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
 494            this.available_context_servers_changed(cx);
 495        }
 496        this
 497    }
 498
 499    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 500        self.servers.get(id).map(|state| state.server())
 501    }
 502
 503    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 504        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 505            Some(server.clone())
 506        } else {
 507            None
 508        }
 509    }
 510
 511    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 512        self.servers.get(id).map(ContextServerStatus::from_state)
 513    }
 514
 515    pub fn configuration_for_server(
 516        &self,
 517        id: &ContextServerId,
 518    ) -> Option<Arc<ContextServerConfiguration>> {
 519        self.servers.get(id).map(|state| state.configuration())
 520    }
 521
 522    /// Returns a sorted slice of available unique context server IDs. Within the
 523    /// slice, context servers which have `mcp-server-` as a prefix in their ID will
 524    /// appear after servers that do not have this prefix in their ID.
 525    pub fn server_ids(&self) -> &[ContextServerId] {
 526        self.server_ids.as_slice()
 527    }
 528
 529    fn populate_server_ids(&mut self, cx: &App) {
 530        self.server_ids = self
 531            .servers
 532            .keys()
 533            .cloned()
 534            .chain(
 535                self.registry
 536                    .read(cx)
 537                    .context_server_descriptors()
 538                    .into_iter()
 539                    .map(|(id, _)| ContextServerId(id)),
 540            )
 541            .chain(
 542                self.context_server_settings
 543                    .keys()
 544                    .map(|id| ContextServerId(id.clone())),
 545            )
 546            .unique()
 547            .sorted_unstable_by(
 548                // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
 549                |a, b| {
 550                    const MCP_PREFIX: &str = "mcp-server-";
 551                    match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
 552                        // If one has mcp-server- prefix and other doesn't, non-mcp comes first
 553                        (Some(_), None) => std::cmp::Ordering::Greater,
 554                        (None, Some(_)) => std::cmp::Ordering::Less,
 555                        // If both have same prefix status, sort by appropriate key
 556                        (Some(a), Some(b)) => a.cmp(b),
 557                        (None, None) => a.0.cmp(&b.0),
 558                    }
 559                },
 560            )
 561            .collect();
 562    }
 563
 564    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 565        self.servers
 566            .values()
 567            .filter_map(|state| {
 568                if let ContextServerState::Running { server, .. } = state {
 569                    Some(server.clone())
 570                } else {
 571                    None
 572                }
 573            })
 574            .collect()
 575    }
 576
 577    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 578        cx.spawn(async move |this, cx| {
 579            let this = this.upgrade().context("Context server store dropped")?;
 580            let id = server.id();
 581            let settings = this
 582                .update(cx, |this, _| {
 583                    this.context_server_settings.get(&id.0).cloned()
 584                })
 585                .context("Failed to get context server settings")?;
 586
 587            if !settings.enabled() {
 588                return anyhow::Ok(());
 589            }
 590
 591            let (registry, worktree_store) = this.update(cx, |this, _| {
 592                (this.registry.clone(), this.worktree_store.clone())
 593            });
 594            let configuration = ContextServerConfiguration::from_settings(
 595                settings,
 596                id.clone(),
 597                registry,
 598                worktree_store,
 599                cx,
 600            )
 601            .await
 602            .context("Failed to create context server configuration")?;
 603
 604            this.update(cx, |this, cx| {
 605                this.run_server(server, Arc::new(configuration), cx)
 606            });
 607            Ok(())
 608        })
 609        .detach_and_log_err(cx);
 610    }
 611
 612    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 613        if matches!(
 614            self.servers.get(id),
 615            Some(ContextServerState::Stopped { .. })
 616        ) {
 617            return Ok(());
 618        }
 619
 620        let state = self
 621            .servers
 622            .remove(id)
 623            .context("Context server not found")?;
 624
 625        let server = state.server();
 626        let configuration = state.configuration();
 627        let mut result = Ok(());
 628        if let ContextServerState::Running { server, .. } = &state {
 629            result = server.stop();
 630        }
 631        drop(state);
 632
 633        self.update_server_state(
 634            id.clone(),
 635            ContextServerState::Stopped {
 636                configuration,
 637                server,
 638            },
 639            cx,
 640        );
 641
 642        result
 643    }
 644
 645    fn run_server(
 646        &mut self,
 647        server: Arc<ContextServer>,
 648        configuration: Arc<ContextServerConfiguration>,
 649        cx: &mut Context<Self>,
 650    ) {
 651        let id = server.id();
 652        if matches!(
 653            self.servers.get(&id),
 654            Some(
 655                ContextServerState::Starting { .. }
 656                    | ContextServerState::Running { .. }
 657                    | ContextServerState::Authenticating { .. },
 658            )
 659        ) {
 660            self.stop_server(&id, cx).log_err();
 661        }
 662        let task = cx.spawn({
 663            let id = server.id();
 664            let server = server.clone();
 665            let configuration = configuration.clone();
 666
 667            async move |this, cx| {
 668                let new_state = match server.clone().start(cx).await {
 669                    Ok(_) => {
 670                        debug_assert!(server.client().is_some());
 671                        ContextServerState::Running {
 672                            server,
 673                            configuration,
 674                        }
 675                    }
 676                    Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
 677                };
 678                this.update(cx, |this, cx| {
 679                    this.update_server_state(id.clone(), new_state, cx)
 680                })
 681                .log_err();
 682            }
 683        });
 684
 685        self.update_server_state(
 686            id.clone(),
 687            ContextServerState::Starting {
 688                configuration,
 689                _task: task,
 690                server,
 691            },
 692            cx,
 693        );
 694    }
 695
 696    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 697        let state = self
 698            .servers
 699            .remove(id)
 700            .context("Context server not found")?;
 701
 702        if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
 703            let server_url = url.clone();
 704            let id = id.clone();
 705            cx.spawn(async move |_this, cx| {
 706                let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 707                if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
 708                {
 709                    log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
 710                }
 711            })
 712            .detach();
 713        }
 714
 715        drop(state);
 716        cx.emit(ServerStatusChangedEvent {
 717            server_id: id.clone(),
 718            status: ContextServerStatus::Stopped,
 719        });
 720        Ok(())
 721    }
 722
 723    pub async fn create_context_server(
 724        this: WeakEntity<Self>,
 725        id: ContextServerId,
 726        configuration: Arc<ContextServerConfiguration>,
 727        cx: &mut AsyncApp,
 728    ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
 729        let remote = configuration.remote();
 730        let needs_remote_command = match configuration.as_ref() {
 731            ContextServerConfiguration::Custom { .. }
 732            | ContextServerConfiguration::Extension { .. } => remote,
 733            ContextServerConfiguration::Http { .. } => false,
 734        };
 735
 736        let (remote_state, is_remote_project) = this.update(cx, |this, _| {
 737            let remote_state = match &this.state {
 738                ContextServerStoreState::Remote {
 739                    project_id,
 740                    upstream_client,
 741                } if needs_remote_command => Some((*project_id, upstream_client.clone())),
 742                _ => None,
 743            };
 744            (remote_state, this.is_remote_project())
 745        })?;
 746
 747        let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
 748            this.project
 749                .as_ref()
 750                .and_then(|project| {
 751                    project
 752                        .read_with(cx, |project, cx| project.active_project_directory(cx))
 753                        .ok()
 754                        .flatten()
 755                })
 756                .or_else(|| {
 757                    this.worktree_store.read_with(cx, |store, cx| {
 758                        store.visible_worktrees(cx).fold(None, |acc, item| {
 759                            if acc.is_none() {
 760                                item.read(cx).root_dir()
 761                            } else {
 762                                acc
 763                            }
 764                        })
 765                    })
 766                })
 767        })?;
 768
 769        let configuration = if let Some((project_id, upstream_client)) = remote_state {
 770            let root_dir = root_path.as_ref().map(|p| p.display().to_string());
 771
 772            let response = upstream_client
 773                .update(cx, |client, _| {
 774                    client
 775                        .proto_client()
 776                        .request(proto::GetContextServerCommand {
 777                            project_id,
 778                            server_id: id.0.to_string(),
 779                            root_dir: root_dir.clone(),
 780                        })
 781                })
 782                .await?;
 783
 784            let remote_command = upstream_client.update(cx, |client, _| {
 785                client.build_command(
 786                    Some(response.path),
 787                    &response.args,
 788                    &response.env.into_iter().collect(),
 789                    root_dir,
 790                    None,
 791                )
 792            })?;
 793
 794            let command = ContextServerCommand {
 795                path: remote_command.program.into(),
 796                args: remote_command.args,
 797                env: Some(remote_command.env.into_iter().collect()),
 798                timeout: None,
 799            };
 800
 801            Arc::new(ContextServerConfiguration::Custom { command, remote })
 802        } else {
 803            configuration
 804        };
 805
 806        if let Some(server) = this.update(cx, |this, _| {
 807            this.context_server_factory
 808                .as_ref()
 809                .map(|factory| factory(id.clone(), configuration.clone()))
 810        })? {
 811            return Ok((server, configuration));
 812        }
 813
 814        let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
 815            if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
 816                if configuration.has_static_auth_header() {
 817                    None
 818                } else {
 819                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 820                    let http_client = cx.update(|cx| cx.http_client());
 821
 822                    match Self::load_session(&credentials_provider, url, &cx).await {
 823                        Ok(Some(session)) => {
 824                            log::info!("{} loaded cached OAuth session from keychain", id);
 825                            Some(Self::create_oauth_token_provider(
 826                                &id,
 827                                url,
 828                                session,
 829                                http_client,
 830                                credentials_provider,
 831                                cx,
 832                            ))
 833                        }
 834                        Ok(None) => None,
 835                        Err(err) => {
 836                            log::warn!("{} failed to load cached OAuth session: {}", id, err);
 837                            None
 838                        }
 839                    }
 840                }
 841            } else {
 842                None
 843            };
 844
 845        let server: Arc<ContextServer> = this.update(cx, |this, cx| {
 846            let global_timeout =
 847                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
 848
 849            match configuration.as_ref() {
 850                ContextServerConfiguration::Http {
 851                    url,
 852                    headers,
 853                    timeout,
 854                    oauth: _,
 855                } => {
 856                    let transport = HttpTransport::new_with_token_provider(
 857                        cx.http_client(),
 858                        url.to_string(),
 859                        headers.clone(),
 860                        cx.background_executor().clone(),
 861                        cached_token_provider.clone(),
 862                    );
 863                    anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
 864                        id,
 865                        Arc::new(transport),
 866                        Some(Duration::from_secs(
 867                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
 868                        )),
 869                    )))
 870                }
 871                _ => {
 872                    let mut command = configuration
 873                        .command()
 874                        .context("Missing command configuration for stdio context server")?
 875                        .clone();
 876                    command.timeout = Some(
 877                        command
 878                            .timeout
 879                            .unwrap_or(global_timeout)
 880                            .min(MAX_TIMEOUT_SECS),
 881                    );
 882
 883                    // Don't pass remote paths as working directory for locally-spawned processes
 884                    let working_directory = if is_remote_project { None } else { root_path };
 885                    anyhow::Ok(Arc::new(ContextServer::stdio(
 886                        id,
 887                        command,
 888                        working_directory,
 889                    )))
 890                }
 891            }
 892        })??;
 893
 894        Ok((server, configuration))
 895    }
 896
 897    async fn handle_get_context_server_command(
 898        this: Entity<Self>,
 899        envelope: TypedEnvelope<proto::GetContextServerCommand>,
 900        mut cx: AsyncApp,
 901    ) -> Result<proto::ContextServerCommand> {
 902        let server_id = ContextServerId(envelope.payload.server_id.into());
 903
 904        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
 905            let ContextServerStoreState::Local {
 906                is_headless: true, ..
 907            } = &this.state
 908            else {
 909                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
 910            };
 911
 912            let settings = this
 913                .context_server_settings
 914                .get(&server_id.0)
 915                .cloned()
 916                .or_else(|| {
 917                    this.registry
 918                        .read(inner_cx)
 919                        .context_server_descriptor(&server_id.0)
 920                        .map(|_| ContextServerSettings::default_extension())
 921                })
 922                .with_context(|| format!("context server `{}` not found", server_id))?;
 923
 924            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
 925        })?;
 926
 927        let configuration = ContextServerConfiguration::from_settings(
 928            settings,
 929            server_id.clone(),
 930            registry,
 931            worktree_store,
 932            &cx,
 933        )
 934        .await
 935        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
 936
 937        let command = configuration
 938            .command()
 939            .context("context server has no command (HTTP servers don't need RPC)")?;
 940
 941        Ok(proto::ContextServerCommand {
 942            path: command.path.display().to_string(),
 943            args: command.args.clone(),
 944            env: command
 945                .env
 946                .clone()
 947                .map(|env| env.into_iter().collect())
 948                .unwrap_or_default(),
 949        })
 950    }
 951
 952    fn resolve_project_settings<'a>(
 953        worktree_store: &'a Entity<WorktreeStore>,
 954        cx: &'a App,
 955    ) -> &'a ProjectSettings {
 956        let location = worktree_store
 957            .read(cx)
 958            .visible_worktrees(cx)
 959            .next()
 960            .map(|worktree| settings::SettingsLocation {
 961                worktree_id: worktree.read(cx).id(),
 962                path: RelPath::empty(),
 963            });
 964        ProjectSettings::get(location, cx)
 965    }
 966
 967    fn create_oauth_token_provider(
 968        id: &ContextServerId,
 969        server_url: &url::Url,
 970        session: OAuthSession,
 971        http_client: Arc<dyn HttpClient>,
 972        credentials_provider: Arc<dyn CredentialsProvider>,
 973        cx: &mut AsyncApp,
 974    ) -> Arc<dyn oauth::OAuthTokenProvider> {
 975        let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
 976        let id = id.clone();
 977        let server_url = server_url.clone();
 978
 979        cx.spawn(async move |cx| {
 980            while let Some(refreshed_session) = token_refresh_rx.next().await {
 981                if let Err(err) =
 982                    Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
 983                        .await
 984                {
 985                    log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
 986                }
 987            }
 988            log::debug!("{} OAuth session persistence task ended", id);
 989        })
 990        .detach();
 991
 992        Arc::new(McpOAuthTokenProvider::new(
 993            session,
 994            http_client,
 995            Some(token_refresh_tx),
 996        ))
 997    }
 998
 999    /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
1000    ///
1001    /// This starts a loopback HTTP callback server on an ephemeral port, builds
1002    /// the authorization URL, opens the user's browser, waits for the callback,
1003    /// exchanges the code for tokens, persists them in the keychain, and restarts
1004    /// the server with the new token provider.
1005    pub fn authenticate_server(
1006        &mut self,
1007        id: &ContextServerId,
1008        cx: &mut Context<Self>,
1009    ) -> Result<()> {
1010        let state = self.servers.get(id).context("Context server not found")?;
1011
1012        let (discovery, server, configuration) = match state {
1013            ContextServerState::AuthRequired {
1014                discovery,
1015                server,
1016                configuration,
1017            } => (discovery.clone(), server.clone(), configuration.clone()),
1018            _ => anyhow::bail!("Server is not in AuthRequired state"),
1019        };
1020
1021        // Check if the configuration has pre-registered OAuth credentials that
1022        // need a client_secret we don't have yet.
1023        let needs_secret_prompt = match configuration.as_ref() {
1024            ContextServerConfiguration::Http {
1025                url,
1026                oauth: Some(oauth_settings),
1027                ..
1028            } if oauth_settings.client_secret.is_none() => Some(url.clone()),
1029            _ => None,
1030        };
1031
1032        let id = id.clone();
1033
1034        if let Some(server_url) = needs_secret_prompt {
1035            // Check keychain for the secret asynchronously.
1036            let task = cx.spawn({
1037                let id = id.clone();
1038                let server = server.clone();
1039                let configuration = configuration.clone();
1040                async move |this, cx| {
1041                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1042                    let keychain_secret =
1043                        Self::load_client_secret(&credentials_provider, &server_url, cx)
1044                            .await
1045                            .ok()
1046                            .flatten();
1047
1048                    if keychain_secret.is_some() {
1049                        // Secret found in keychain, proceed with OAuth flow.
1050                        let result = Self::run_oauth_flow(
1051                            this.clone(),
1052                            id.clone(),
1053                            discovery.clone(),
1054                            configuration.clone(),
1055                            cx,
1056                        )
1057                        .await;
1058
1059                        if let Err(err) = &result {
1060                            log::error!("{} OAuth authentication failed: {:?}", id, err);
1061                            this.update(cx, |this, cx| {
1062                                this.update_server_state(
1063                                    id.clone(),
1064                                    ContextServerState::AuthRequired {
1065                                        server,
1066                                        configuration,
1067                                        discovery,
1068                                    },
1069                                    cx,
1070                                )
1071                            })
1072                            .log_err();
1073                        }
1074                    } else {
1075                        // No secret anywhere — prompt the user.
1076                        this.update(cx, |this, cx| {
1077                            this.update_server_state(
1078                                id.clone(),
1079                                ContextServerState::ClientSecretRequired {
1080                                    server,
1081                                    configuration,
1082                                    discovery,
1083                                },
1084                                cx,
1085                            );
1086                        })
1087                        .log_err();
1088                    }
1089                }
1090            });
1091
1092            self.update_server_state(
1093                id,
1094                ContextServerState::Authenticating {
1095                    server,
1096                    configuration,
1097                    _task: task,
1098                },
1099                cx,
1100            );
1101        } else {
1102            // No pre-registration, or secret already in settings — proceed directly.
1103            let task = cx.spawn({
1104                let id = id.clone();
1105                let server = server.clone();
1106                let configuration = configuration.clone();
1107                async move |this, cx| {
1108                    let result = Self::run_oauth_flow(
1109                        this.clone(),
1110                        id.clone(),
1111                        discovery.clone(),
1112                        configuration.clone(),
1113                        cx,
1114                    )
1115                    .await;
1116
1117                    if let Err(err) = &result {
1118                        log::error!("{} OAuth authentication failed: {:?}", id, err);
1119                        this.update(cx, |this, cx| {
1120                            this.update_server_state(
1121                                id.clone(),
1122                                ContextServerState::AuthRequired {
1123                                    server,
1124                                    configuration,
1125                                    discovery,
1126                                },
1127                                cx,
1128                            )
1129                        })
1130                        .log_err();
1131                    }
1132                }
1133            });
1134
1135            self.update_server_state(
1136                id,
1137                ContextServerState::Authenticating {
1138                    server,
1139                    configuration,
1140                    _task: task,
1141                },
1142                cx,
1143            );
1144        }
1145
1146        Ok(())
1147    }
1148
1149    /// Store an interactively-provided client secret and proceed with authentication.
1150    pub fn submit_client_secret(
1151        &mut self,
1152        id: &ContextServerId,
1153        secret: String,
1154        cx: &mut Context<Self>,
1155    ) -> Result<()> {
1156        let state = self.servers.get(id).context("Context server not found")?;
1157
1158        let (server, configuration, discovery) = match state {
1159            ContextServerState::ClientSecretRequired {
1160                server,
1161                configuration,
1162                discovery,
1163            } => (server.clone(), configuration.clone(), discovery.clone()),
1164            _ => anyhow::bail!("Server is not in ClientSecretRequired state"),
1165        };
1166
1167        let server_url = match configuration.as_ref() {
1168            ContextServerConfiguration::Http { url, .. } => url.clone(),
1169            _ => anyhow::bail!("OAuth only supported for HTTP servers"),
1170        };
1171
1172        let id = id.clone();
1173
1174        let task = cx.spawn({
1175            let id = id.clone();
1176            let server = server.clone();
1177            let configuration = configuration.clone();
1178            async move |this, cx| {
1179                // Store the secret if non-empty (empty means public client / skip).
1180                if !secret.is_empty() {
1181                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1182                    if let Err(err) =
1183                        Self::store_client_secret(&credentials_provider, &server_url, &secret, cx)
1184                            .await
1185                    {
1186                        log::error!(
1187                            "{} failed to store client secret in keychain: {:?}",
1188                            id,
1189                            err
1190                        );
1191                    }
1192                }
1193
1194                let result = Self::run_oauth_flow(
1195                    this.clone(),
1196                    id.clone(),
1197                    discovery.clone(),
1198                    configuration.clone(),
1199                    cx,
1200                )
1201                .await;
1202
1203                if let Err(err) = &result {
1204                    log::error!("{} OAuth authentication failed: {:?}", id, err);
1205                    this.update(cx, |this, cx| {
1206                        this.update_server_state(
1207                            id.clone(),
1208                            ContextServerState::AuthRequired {
1209                                server,
1210                                configuration,
1211                                discovery,
1212                            },
1213                            cx,
1214                        )
1215                    })
1216                    .log_err();
1217                }
1218            }
1219        });
1220
1221        self.update_server_state(
1222            id,
1223            ContextServerState::Authenticating {
1224                server,
1225                configuration,
1226                _task: task,
1227            },
1228            cx,
1229        );
1230
1231        Ok(())
1232    }
1233
1234    async fn run_oauth_flow(
1235        this: WeakEntity<Self>,
1236        id: ContextServerId,
1237        discovery: Arc<OAuthDiscovery>,
1238        configuration: Arc<ContextServerConfiguration>,
1239        cx: &mut AsyncApp,
1240    ) -> Result<()> {
1241        let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1242        let pkce = oauth::generate_pkce_challenge();
1243
1244        let mut state_bytes = [0u8; 32];
1245        rand::rng().fill(&mut state_bytes);
1246        let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1247
1248        // Start a loopback HTTP server on an ephemeral port. The redirect URI
1249        // includes this port so the browser sends the callback directly to our
1250        // process.
1251        let (redirect_uri, callback_rx) = oauth::start_callback_server()
1252            .await
1253            .context("Failed to start OAuth callback server")?;
1254
1255        let http_client = cx.update(|cx| cx.http_client());
1256        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1257        let server_url = match configuration.as_ref() {
1258            ContextServerConfiguration::Http { url, .. } => url.clone(),
1259            _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1260        };
1261
1262        let client_registration = match configuration.as_ref() {
1263            ContextServerConfiguration::Http {
1264                url,
1265                oauth: Some(oauth_settings),
1266                ..
1267            } => {
1268                // Pre-registered client. Resolve the secret from settings, then keychain.
1269                let client_secret = if oauth_settings.client_secret.is_some() {
1270                    oauth_settings.client_secret.clone()
1271                } else {
1272                    Self::load_client_secret(&credentials_provider, url, cx)
1273                        .await
1274                        .ok()
1275                        .flatten()
1276                };
1277                oauth::OAuthClientRegistration {
1278                    client_id: oauth_settings.client_id.clone(),
1279                    client_secret,
1280                }
1281            }
1282            _ => oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1283                .await
1284                .context("Failed to resolve OAuth client registration")?,
1285        };
1286
1287        let auth_url = oauth::build_authorization_url(
1288            &discovery.auth_server_metadata,
1289            &client_registration.client_id,
1290            &redirect_uri,
1291            &discovery.scopes,
1292            &resource,
1293            &pkce,
1294            &state_param,
1295        );
1296
1297        cx.update(|cx| cx.open_url(auth_url.as_str()));
1298
1299        let callback = callback_rx
1300            .await
1301            .map_err(|_| {
1302                anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1303            })?
1304            .context("OAuth callback server received an invalid request")?;
1305
1306        if callback.state != state_param {
1307            anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1308        }
1309
1310        let tokens = oauth::exchange_code(
1311            &http_client,
1312            &discovery.auth_server_metadata,
1313            &callback.code,
1314            &client_registration.client_id,
1315            &redirect_uri,
1316            &pkce.verifier,
1317            &resource,
1318            client_registration.client_secret.as_deref(),
1319        )
1320        .await
1321        .context("Failed to exchange authorization code for tokens")?;
1322
1323        let session = OAuthSession {
1324            token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1325            resource: discovery.resource_metadata.resource.clone(),
1326            client_registration,
1327            tokens,
1328        };
1329
1330        Self::store_session(&credentials_provider, &server_url, &session, cx)
1331            .await
1332            .context("Failed to persist OAuth session in keychain")?;
1333
1334        let token_provider = Self::create_oauth_token_provider(
1335            &id,
1336            &server_url,
1337            session,
1338            http_client.clone(),
1339            credentials_provider,
1340            cx,
1341        );
1342
1343        let new_server = this.update(cx, |this, cx| {
1344            let global_timeout =
1345                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1346
1347            match configuration.as_ref() {
1348                ContextServerConfiguration::Http {
1349                    url,
1350                    headers,
1351                    timeout,
1352                    oauth: _,
1353                } => {
1354                    let transport = HttpTransport::new_with_token_provider(
1355                        http_client.clone(),
1356                        url.to_string(),
1357                        headers.clone(),
1358                        cx.background_executor().clone(),
1359                        Some(token_provider.clone()),
1360                    );
1361                    Ok(Arc::new(ContextServer::new_with_timeout(
1362                        id.clone(),
1363                        Arc::new(transport),
1364                        Some(Duration::from_secs(
1365                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1366                        )),
1367                    )))
1368                }
1369                _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1370            }
1371        })??;
1372
1373        this.update(cx, |this, cx| {
1374            this.run_server(new_server, configuration, cx);
1375        })?;
1376
1377        Ok(())
1378    }
1379
1380    /// Store the full OAuth session in the system keychain, keyed by the
1381    /// server's canonical URI.
1382    async fn store_session(
1383        credentials_provider: &Arc<dyn CredentialsProvider>,
1384        server_url: &url::Url,
1385        session: &OAuthSession,
1386        cx: &AsyncApp,
1387    ) -> Result<()> {
1388        let key = Self::keychain_key(server_url);
1389        let json = serde_json::to_string(session)?;
1390        credentials_provider
1391            .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1392            .await
1393    }
1394
1395    /// Load the full OAuth session from the system keychain for the given
1396    /// server URL.
1397    async fn load_session(
1398        credentials_provider: &Arc<dyn CredentialsProvider>,
1399        server_url: &url::Url,
1400        cx: &AsyncApp,
1401    ) -> Result<Option<OAuthSession>> {
1402        let key = Self::keychain_key(server_url);
1403        match credentials_provider.read_credentials(&key, cx).await? {
1404            Some((_username, password_bytes)) => {
1405                let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1406                Ok(Some(session))
1407            }
1408            None => Ok(None),
1409        }
1410    }
1411
1412    /// Clear the stored OAuth session from the system keychain.
1413    async fn clear_session(
1414        credentials_provider: &Arc<dyn CredentialsProvider>,
1415        server_url: &url::Url,
1416        cx: &AsyncApp,
1417    ) -> Result<()> {
1418        let key = Self::keychain_key(server_url);
1419        credentials_provider.delete_credentials(&key, cx).await
1420    }
1421
1422    fn keychain_key(server_url: &url::Url) -> String {
1423        format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1424    }
1425
1426    fn client_secret_keychain_key(server_url: &url::Url) -> String {
1427        format!(
1428            "mcp-oauth-client-secret:{}",
1429            oauth::canonical_server_uri(server_url)
1430        )
1431    }
1432
1433    async fn load_client_secret(
1434        credentials_provider: &Arc<dyn CredentialsProvider>,
1435        server_url: &url::Url,
1436        cx: &AsyncApp,
1437    ) -> Result<Option<String>> {
1438        let key = Self::client_secret_keychain_key(server_url);
1439        match credentials_provider.read_credentials(&key, cx).await? {
1440            Some((_username, secret_bytes)) => Ok(Some(String::from_utf8(secret_bytes)?)),
1441            None => Ok(None),
1442        }
1443    }
1444
1445    pub async fn store_client_secret(
1446        credentials_provider: &Arc<dyn CredentialsProvider>,
1447        server_url: &url::Url,
1448        secret: &str,
1449        cx: &AsyncApp,
1450    ) -> Result<()> {
1451        let key = Self::client_secret_keychain_key(server_url);
1452        credentials_provider
1453            .write_credentials(&key, "mcp-oauth-client-secret", secret.as_bytes(), cx)
1454            .await
1455    }
1456
1457    async fn clear_client_secret(
1458        credentials_provider: &Arc<dyn CredentialsProvider>,
1459        server_url: &url::Url,
1460        cx: &AsyncApp,
1461    ) -> Result<()> {
1462        let key = Self::client_secret_keychain_key(server_url);
1463        credentials_provider.delete_credentials(&key, cx).await
1464    }
1465
1466    /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1467    /// session from the keychain and stop the server.
1468    pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1469        let state = self.servers.get(id).context("Context server not found")?;
1470        let configuration = state.configuration();
1471
1472        let server_url = match configuration.as_ref() {
1473            ContextServerConfiguration::Http { url, .. } => url.clone(),
1474            _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1475        };
1476
1477        let id = id.clone();
1478        self.stop_server(&id, cx)?;
1479
1480        cx.spawn(async move |this, cx| {
1481            let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1482            if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1483                log::error!("{} failed to clear OAuth session: {}", id, err);
1484            }
1485            // Also clear any interactively-provided client secret so the user
1486            // gets a fresh prompt on the next authentication attempt.
1487            Self::clear_client_secret(&credentials_provider, &server_url, &cx)
1488                .await
1489                .log_err();
1490            // Trigger server recreation so the next start uses a fresh
1491            // transport without the old (now-invalidated) token provider.
1492            this.update(cx, |this, cx| {
1493                this.available_context_servers_changed(cx);
1494            })
1495            .log_err();
1496        })
1497        .detach();
1498
1499        Ok(())
1500    }
1501
1502    fn update_server_state(
1503        &mut self,
1504        id: ContextServerId,
1505        state: ContextServerState,
1506        cx: &mut Context<Self>,
1507    ) {
1508        let status = ContextServerStatus::from_state(&state);
1509        self.servers.insert(id.clone(), state);
1510        cx.emit(ServerStatusChangedEvent {
1511            server_id: id,
1512            status,
1513        });
1514    }
1515
1516    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1517        if self.update_servers_task.is_some() {
1518            self.needs_server_update = true;
1519        } else {
1520            self.needs_server_update = false;
1521            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1522                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1523                    log::error!("Error maintaining context servers: {}", err);
1524                }
1525
1526                this.update(cx, |this, cx| {
1527                    this.populate_server_ids(cx);
1528                    cx.notify();
1529                    this.update_servers_task.take();
1530                    if this.needs_server_update {
1531                        this.available_context_servers_changed(cx);
1532                    }
1533                })?;
1534
1535                Ok(())
1536            }));
1537        }
1538    }
1539
1540    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
1541        // Don't start context servers if AI is disabled
1542        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
1543        if ai_disabled {
1544            // Stop all running servers when AI is disabled
1545            this.update(cx, |this, cx| {
1546                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
1547                for id in server_ids {
1548                    let _ = this.stop_server(&id, cx);
1549                }
1550            })?;
1551            return Ok(());
1552        }
1553
1554        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
1555            (
1556                this.context_server_settings.clone(),
1557                this.registry.clone(),
1558                this.worktree_store.clone(),
1559            )
1560        })?;
1561
1562        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
1563            configured_servers
1564                .entry(id)
1565                .or_insert(ContextServerSettings::default_extension());
1566        }
1567
1568        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
1569            configured_servers
1570                .into_iter()
1571                .partition(|(_, settings)| settings.enabled());
1572
1573        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
1574            let id = ContextServerId(id);
1575            ContextServerConfiguration::from_settings(
1576                settings,
1577                id.clone(),
1578                registry.clone(),
1579                worktree_store.clone(),
1580                cx,
1581            )
1582            .map(move |config| (id, config))
1583        }))
1584        .await
1585        .into_iter()
1586        .filter_map(|(id, config)| config.map(|config| (id, config)))
1587        .collect::<HashMap<_, _>>();
1588
1589        let mut servers_to_start = Vec::new();
1590        let mut servers_to_remove = HashSet::default();
1591        let mut servers_to_stop = HashSet::default();
1592
1593        this.update(cx, |this, _cx| {
1594            for server_id in this.servers.keys() {
1595                // All servers that are not in desired_servers should be removed from the store.
1596                // This can happen if the user removed a server from the context server settings.
1597                if !configured_servers.contains_key(server_id) {
1598                    if disabled_servers.contains_key(&server_id.0) {
1599                        servers_to_stop.insert(server_id.clone());
1600                    } else {
1601                        servers_to_remove.insert(server_id.clone());
1602                    }
1603                }
1604            }
1605
1606            for (id, config) in configured_servers {
1607                let state = this.servers.get(&id);
1608                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
1609                let existing_config = state.as_ref().map(|state| state.configuration());
1610                if existing_config.as_deref() != Some(&config) || is_stopped {
1611                    let config = Arc::new(config);
1612                    servers_to_start.push((id.clone(), config));
1613                    if this.servers.contains_key(&id) {
1614                        servers_to_stop.insert(id);
1615                    }
1616                }
1617            }
1618
1619            anyhow::Ok(())
1620        })??;
1621
1622        this.update(cx, |this, inner_cx| {
1623            for id in servers_to_stop {
1624                this.stop_server(&id, inner_cx)?;
1625            }
1626            for id in servers_to_remove {
1627                this.remove_server(&id, inner_cx)?;
1628            }
1629            anyhow::Ok(())
1630        })??;
1631
1632        for (id, config) in servers_to_start {
1633            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
1634                Ok((server, config)) => {
1635                    this.update(cx, |this, cx| {
1636                        this.run_server(server, config, cx);
1637                    })?;
1638                }
1639                Err(err) => {
1640                    log::error!("{id} context server failed to create: {err:#}");
1641                    this.update(cx, |_this, cx| {
1642                        cx.emit(ServerStatusChangedEvent {
1643                            server_id: id,
1644                            status: ContextServerStatus::Error(err.to_string().into()),
1645                        });
1646                        cx.notify();
1647                    })?;
1648                }
1649            }
1650        }
1651
1652        Ok(())
1653    }
1654}
1655
1656/// Determines the appropriate server state after a start attempt fails.
1657///
1658/// When the error is an HTTP 401 with no static auth header configured,
1659/// attempts OAuth discovery so the UI can offer an authentication flow.
1660async fn resolve_start_failure(
1661    id: &ContextServerId,
1662    err: anyhow::Error,
1663    server: Arc<ContextServer>,
1664    configuration: Arc<ContextServerConfiguration>,
1665    cx: &AsyncApp,
1666) -> ContextServerState {
1667    let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1668        TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1669    });
1670
1671    if www_authenticate.is_some() && configuration.has_static_auth_header() {
1672        log::warn!("{id} received 401 with a static Authorization header configured");
1673        return ContextServerState::Error {
1674            configuration,
1675            server,
1676            error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1677                .into(),
1678        };
1679    }
1680
1681    let server_url = match configuration.as_ref() {
1682        ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1683            url.clone()
1684        }
1685        _ => {
1686            if www_authenticate.is_some() {
1687                log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1688            } else {
1689                log::error!("{id} context server failed to start: {err}");
1690            }
1691            return ContextServerState::Error {
1692                configuration,
1693                server,
1694                error: err.to_string().into(),
1695            };
1696        }
1697    };
1698
1699    // When the error is NOT a 401 but there is a cached OAuth session in the
1700    // keychain, the session is likely stale/expired and caused the failure
1701    // (e.g. timeout because the server rejected the token silently). Clear it
1702    // so the next start attempt can get a clean 401 and trigger the auth flow.
1703    if www_authenticate.is_none() {
1704        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1705        match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1706            Ok(Some(_)) => {
1707                log::info!("{id} start failed with a cached OAuth session present; clearing it");
1708                ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1709                    .await
1710                    .log_err();
1711            }
1712            _ => {
1713                log::error!("{id} context server failed to start: {err}");
1714                return ContextServerState::Error {
1715                    configuration,
1716                    server,
1717                    error: err.to_string().into(),
1718                };
1719            }
1720        }
1721    }
1722
1723    let default_www_authenticate = oauth::WwwAuthenticate {
1724        resource_metadata: None,
1725        scope: None,
1726        error: None,
1727        error_description: None,
1728    };
1729    let www_authenticate = www_authenticate
1730        .as_ref()
1731        .unwrap_or(&default_www_authenticate);
1732    let http_client = cx.update(|cx| cx.http_client());
1733
1734    match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1735        Ok(discovery) => {
1736            use context_server::oauth::{
1737                ClientRegistrationStrategy, determine_registration_strategy,
1738            };
1739
1740            let has_preregistered_client_id = matches!(
1741                configuration.as_ref(),
1742                ContextServerConfiguration::Http { oauth: Some(_), .. }
1743            );
1744
1745            let strategy = determine_registration_strategy(&discovery.auth_server_metadata);
1746
1747            if matches!(strategy, ClientRegistrationStrategy::Unavailable)
1748                && !has_preregistered_client_id
1749            {
1750                log::error!(
1751                    "{id} authorization server supports neither CIMD nor DCR, \
1752                     and no pre-registered client_id is configured"
1753                );
1754                return ContextServerState::Error {
1755                    configuration,
1756                    server,
1757                    error: "Authorization server supports neither CIMD nor DCR. \
1758                            Configure a pre-registered client_id in your settings \
1759                            under the \"oauth\" key."
1760                        .into(),
1761                };
1762            }
1763
1764            log::info!(
1765                "{id} requires OAuth authorization (auth server: {})",
1766                discovery.auth_server_metadata.issuer,
1767            );
1768            ContextServerState::AuthRequired {
1769                server,
1770                configuration,
1771                discovery: Arc::new(discovery),
1772            }
1773        }
1774        Err(discovery_err) => {
1775            log::error!("{id} OAuth discovery failed: {discovery_err}");
1776            ContextServerState::Error {
1777                configuration,
1778                server,
1779                error: format!("OAuth discovery failed: {discovery_err}").into(),
1780            }
1781        }
1782    }
1783}