context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::path::Path;
   5use std::sync::Arc;
   6use std::time::Duration;
   7
   8use anyhow::{Context as _, Result};
   9use collections::{HashMap, HashSet};
  10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
  11use context_server::transport::{HttpTransport, TransportError};
  12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  13use credentials_provider::CredentialsProvider;
  14use futures::future::Either;
  15use futures::{FutureExt as _, StreamExt as _, future::join_all};
  16use gpui::{
  17    App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, TaskExt, WeakEntity, actions,
  18};
  19use http_client::HttpClient;
  20use itertools::Itertools;
  21use rand::Rng as _;
  22use registry::ContextServerDescriptorRegistry;
  23use remote::RemoteClient;
  24use rpc::{AnyProtoClient, TypedEnvelope, proto};
  25use settings::{Settings as _, SettingsStore};
  26use util::{ResultExt as _, rel_path::RelPath};
  27
  28use crate::{
  29    DisableAiSettings, Project,
  30    project_settings::{ContextServerSettings, ProjectSettings},
  31    worktree_store::WorktreeStore,
  32};
  33
  34/// Maximum timeout for context server requests
  35/// Prevents extremely large timeout values from tying up resources indefinitely.
  36const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
  37
  38pub fn init(cx: &mut App) {
  39    extension::init(cx);
  40}
  41
  42actions!(
  43    context_server,
  44    [
  45        /// Restarts the context server.
  46        Restart
  47    ]
  48);
  49
  50#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  51pub enum ContextServerStatus {
  52    Starting,
  53    Running,
  54    Stopped,
  55    Error(Arc<str>),
  56    /// The server returned 401 and OAuth authorization is needed. The UI
  57    /// should show an "Authenticate" button.
  58    AuthRequired,
  59    /// The OAuth browser flow is in progress — the user has been redirected
  60    /// to the authorization server and we're waiting for the callback.
  61    Authenticating,
  62}
  63
  64impl ContextServerStatus {
  65    fn from_state(state: &ContextServerState) -> Self {
  66        match state {
  67            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  68            ContextServerState::Running { .. } => ContextServerStatus::Running,
  69            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  70            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  71            ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
  72            ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
  73        }
  74    }
  75}
  76
  77enum ContextServerState {
  78    Starting {
  79        server: Arc<ContextServer>,
  80        configuration: Arc<ContextServerConfiguration>,
  81        _task: Task<()>,
  82    },
  83    Running {
  84        server: Arc<ContextServer>,
  85        configuration: Arc<ContextServerConfiguration>,
  86    },
  87    Stopped {
  88        server: Arc<ContextServer>,
  89        configuration: Arc<ContextServerConfiguration>,
  90    },
  91    Error {
  92        server: Arc<ContextServer>,
  93        configuration: Arc<ContextServerConfiguration>,
  94        error: Arc<str>,
  95    },
  96    /// The server requires OAuth authorization before it can be used. The
  97    /// `OAuthDiscovery` holds everything needed to start the browser flow.
  98    AuthRequired {
  99        server: Arc<ContextServer>,
 100        configuration: Arc<ContextServerConfiguration>,
 101        discovery: Arc<OAuthDiscovery>,
 102    },
 103    /// The OAuth browser flow is in progress. The user has been redirected
 104    /// to the authorization server and we're waiting for the callback.
 105    Authenticating {
 106        server: Arc<ContextServer>,
 107        configuration: Arc<ContextServerConfiguration>,
 108        _task: Task<()>,
 109    },
 110}
 111
 112impl ContextServerState {
 113    pub fn server(&self) -> Arc<ContextServer> {
 114        match self {
 115            ContextServerState::Starting { server, .. }
 116            | ContextServerState::Running { server, .. }
 117            | ContextServerState::Stopped { server, .. }
 118            | ContextServerState::Error { server, .. }
 119            | ContextServerState::AuthRequired { server, .. }
 120            | ContextServerState::Authenticating { server, .. } => server.clone(),
 121        }
 122    }
 123
 124    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
 125        match self {
 126            ContextServerState::Starting { configuration, .. }
 127            | ContextServerState::Running { configuration, .. }
 128            | ContextServerState::Stopped { configuration, .. }
 129            | ContextServerState::Error { configuration, .. }
 130            | ContextServerState::AuthRequired { configuration, .. }
 131            | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
 132        }
 133    }
 134}
 135
 136#[derive(Debug, PartialEq, Eq)]
 137pub enum ContextServerConfiguration {
 138    Custom {
 139        command: ContextServerCommand,
 140        remote: bool,
 141    },
 142    Extension {
 143        command: ContextServerCommand,
 144        settings: serde_json::Value,
 145        remote: bool,
 146    },
 147    Http {
 148        url: url::Url,
 149        headers: HashMap<String, String>,
 150        timeout: Option<u64>,
 151    },
 152}
 153
 154impl ContextServerConfiguration {
 155    pub fn command(&self) -> Option<&ContextServerCommand> {
 156        match self {
 157            ContextServerConfiguration::Custom { command, .. } => Some(command),
 158            ContextServerConfiguration::Extension { command, .. } => Some(command),
 159            ContextServerConfiguration::Http { .. } => None,
 160        }
 161    }
 162
 163    pub fn has_static_auth_header(&self) -> bool {
 164        match self {
 165            ContextServerConfiguration::Http { headers, .. } => headers
 166                .keys()
 167                .any(|k| k.eq_ignore_ascii_case("authorization")),
 168            _ => false,
 169        }
 170    }
 171
 172    pub fn remote(&self) -> bool {
 173        match self {
 174            ContextServerConfiguration::Custom { remote, .. } => *remote,
 175            ContextServerConfiguration::Extension { remote, .. } => *remote,
 176            ContextServerConfiguration::Http { .. } => false,
 177        }
 178    }
 179
 180    pub async fn from_settings(
 181        settings: ContextServerSettings,
 182        id: ContextServerId,
 183        registry: Entity<ContextServerDescriptorRegistry>,
 184        worktree_store: Entity<WorktreeStore>,
 185        cx: &AsyncApp,
 186    ) -> Option<Self> {
 187        const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
 188
 189        match settings {
 190            ContextServerSettings::Stdio {
 191                enabled: _,
 192                command,
 193                remote,
 194            } => Some(ContextServerConfiguration::Custom { command, remote }),
 195            ContextServerSettings::Extension {
 196                enabled: _,
 197                settings,
 198                remote,
 199            } => {
 200                let descriptor =
 201                    cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
 202
 203                let command_future = descriptor.command(worktree_store, cx);
 204                let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
 205
 206                match futures::future::select(command_future, timeout_future).await {
 207                    Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
 208                        command,
 209                        settings,
 210                        remote,
 211                    }),
 212                    Either::Left((Err(e), _)) => {
 213                        log::error!(
 214                            "Failed to create context server configuration from settings: {e:#}"
 215                        );
 216                        None
 217                    }
 218                    Either::Right(_) => {
 219                        log::error!(
 220                            "Timed out resolving command for extension context server {id}"
 221                        );
 222                        None
 223                    }
 224                }
 225            }
 226            ContextServerSettings::Http {
 227                enabled: _,
 228                url,
 229                headers: auth,
 230                timeout,
 231            } => {
 232                let url = url::Url::parse(&url).log_err()?;
 233                Some(ContextServerConfiguration::Http {
 234                    url,
 235                    headers: auth,
 236                    timeout,
 237                })
 238            }
 239        }
 240    }
 241}
 242
 243pub type ContextServerFactory =
 244    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 245
 246enum ContextServerStoreState {
 247    Local {
 248        downstream_client: Option<(u64, AnyProtoClient)>,
 249        is_headless: bool,
 250    },
 251    Remote {
 252        project_id: u64,
 253        upstream_client: Entity<RemoteClient>,
 254    },
 255}
 256
 257pub struct ContextServerStore {
 258    state: ContextServerStoreState,
 259    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
 260    servers: HashMap<ContextServerId, ContextServerState>,
 261    server_ids: Vec<ContextServerId>,
 262    worktree_store: Entity<WorktreeStore>,
 263    project: Option<WeakEntity<Project>>,
 264    registry: Entity<ContextServerDescriptorRegistry>,
 265    update_servers_task: Option<Task<Result<()>>>,
 266    context_server_factory: Option<ContextServerFactory>,
 267    needs_server_update: bool,
 268    ai_disabled: bool,
 269    _subscriptions: Vec<Subscription>,
 270}
 271
 272pub struct ServerStatusChangedEvent {
 273    pub server_id: ContextServerId,
 274    pub status: ContextServerStatus,
 275}
 276
 277impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
 278
 279impl ContextServerStore {
 280    pub fn local(
 281        worktree_store: Entity<WorktreeStore>,
 282        weak_project: Option<WeakEntity<Project>>,
 283        headless: bool,
 284        cx: &mut Context<Self>,
 285    ) -> Self {
 286        Self::new_internal(
 287            !headless,
 288            None,
 289            ContextServerDescriptorRegistry::default_global(cx),
 290            worktree_store,
 291            weak_project,
 292            ContextServerStoreState::Local {
 293                downstream_client: None,
 294                is_headless: headless,
 295            },
 296            cx,
 297        )
 298    }
 299
 300    pub fn remote(
 301        project_id: u64,
 302        upstream_client: Entity<RemoteClient>,
 303        worktree_store: Entity<WorktreeStore>,
 304        weak_project: Option<WeakEntity<Project>>,
 305        cx: &mut Context<Self>,
 306    ) -> Self {
 307        Self::new_internal(
 308            true,
 309            None,
 310            ContextServerDescriptorRegistry::default_global(cx),
 311            worktree_store,
 312            weak_project,
 313            ContextServerStoreState::Remote {
 314                project_id,
 315                upstream_client,
 316            },
 317            cx,
 318        )
 319    }
 320
 321    pub fn init_headless(session: &AnyProtoClient) {
 322        session.add_entity_request_handler(Self::handle_get_context_server_command);
 323    }
 324
 325    pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
 326        if let ContextServerStoreState::Local {
 327            downstream_client, ..
 328        } = &mut self.state
 329        {
 330            *downstream_client = Some((project_id, client));
 331        }
 332    }
 333
 334    pub fn is_remote_project(&self) -> bool {
 335        matches!(self.state, ContextServerStoreState::Remote { .. })
 336    }
 337
 338    /// Returns all configured context server ids, excluding the ones that are disabled
 339    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 340        self.context_server_settings
 341            .iter()
 342            .filter(|(_, settings)| settings.enabled())
 343            .map(|(id, _)| ContextServerId(id.clone()))
 344            .collect()
 345    }
 346
 347    #[cfg(feature = "test-support")]
 348    pub fn test(
 349        registry: Entity<ContextServerDescriptorRegistry>,
 350        worktree_store: Entity<WorktreeStore>,
 351        weak_project: Option<WeakEntity<Project>>,
 352        cx: &mut Context<Self>,
 353    ) -> Self {
 354        Self::new_internal(
 355            false,
 356            None,
 357            registry,
 358            worktree_store,
 359            weak_project,
 360            ContextServerStoreState::Local {
 361                downstream_client: None,
 362                is_headless: false,
 363            },
 364            cx,
 365        )
 366    }
 367
 368    #[cfg(feature = "test-support")]
 369    pub fn test_maintain_server_loop(
 370        context_server_factory: Option<ContextServerFactory>,
 371        registry: Entity<ContextServerDescriptorRegistry>,
 372        worktree_store: Entity<WorktreeStore>,
 373        weak_project: Option<WeakEntity<Project>>,
 374        cx: &mut Context<Self>,
 375    ) -> Self {
 376        Self::new_internal(
 377            true,
 378            context_server_factory,
 379            registry,
 380            worktree_store,
 381            weak_project,
 382            ContextServerStoreState::Local {
 383                downstream_client: None,
 384                is_headless: false,
 385            },
 386            cx,
 387        )
 388    }
 389
 390    #[cfg(feature = "test-support")]
 391    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
 392        self.context_server_factory = Some(factory);
 393    }
 394
 395    #[cfg(feature = "test-support")]
 396    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
 397        &self.registry
 398    }
 399
 400    #[cfg(feature = "test-support")]
 401    pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 402        let configuration = Arc::new(ContextServerConfiguration::Custom {
 403            command: ContextServerCommand {
 404                path: "test".into(),
 405                args: vec![],
 406                env: None,
 407                timeout: None,
 408            },
 409            remote: false,
 410        });
 411        self.run_server(server, configuration, cx);
 412    }
 413
 414    fn new_internal(
 415        maintain_server_loop: bool,
 416        context_server_factory: Option<ContextServerFactory>,
 417        registry: Entity<ContextServerDescriptorRegistry>,
 418        worktree_store: Entity<WorktreeStore>,
 419        weak_project: Option<WeakEntity<Project>>,
 420        state: ContextServerStoreState,
 421        cx: &mut Context<Self>,
 422    ) -> Self {
 423        let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 424            let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 425            let ai_was_disabled = this.ai_disabled;
 426            this.ai_disabled = ai_disabled;
 427
 428            let settings =
 429                &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
 430            let settings_changed = &this.context_server_settings != settings;
 431
 432            if settings_changed {
 433                this.context_server_settings = settings.clone();
 434            }
 435
 436            // When AI is disabled, stop all running servers
 437            if ai_disabled {
 438                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
 439                for id in server_ids {
 440                    this.stop_server(&id, cx).log_err();
 441                }
 442                return;
 443            }
 444
 445            // Trigger updates if AI was re-enabled or settings changed
 446            if maintain_server_loop && (ai_was_disabled || settings_changed) {
 447                this.available_context_servers_changed(cx);
 448            }
 449        })];
 450
 451        if maintain_server_loop {
 452            subscriptions.push(cx.observe(&registry, |this, _registry, cx| {
 453                if !DisableAiSettings::get_global(cx).disable_ai {
 454                    this.available_context_servers_changed(cx);
 455                }
 456            }));
 457        }
 458
 459        let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 460        let mut this = Self {
 461            state,
 462            _subscriptions: subscriptions,
 463            context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
 464                .context_servers
 465                .clone(),
 466            worktree_store,
 467            project: weak_project,
 468            registry,
 469            needs_server_update: false,
 470            ai_disabled,
 471            servers: HashMap::default(),
 472            server_ids: Default::default(),
 473            update_servers_task: None,
 474            context_server_factory,
 475        };
 476        if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
 477            this.available_context_servers_changed(cx);
 478        }
 479        this
 480    }
 481
 482    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 483        self.servers.get(id).map(|state| state.server())
 484    }
 485
 486    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 487        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 488            Some(server.clone())
 489        } else {
 490            None
 491        }
 492    }
 493
 494    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 495        self.servers.get(id).map(ContextServerStatus::from_state)
 496    }
 497
 498    pub fn configuration_for_server(
 499        &self,
 500        id: &ContextServerId,
 501    ) -> Option<Arc<ContextServerConfiguration>> {
 502        self.servers.get(id).map(|state| state.configuration())
 503    }
 504
 505    /// Returns a sorted slice of available unique context server IDs. Within the
 506    /// slice, context servers which have `mcp-server-` as a prefix in their ID will
 507    /// appear after servers that do not have this prefix in their ID.
 508    pub fn server_ids(&self) -> &[ContextServerId] {
 509        self.server_ids.as_slice()
 510    }
 511
 512    fn populate_server_ids(&mut self, cx: &App) {
 513        self.server_ids = self
 514            .servers
 515            .keys()
 516            .cloned()
 517            .chain(
 518                self.registry
 519                    .read(cx)
 520                    .context_server_descriptors()
 521                    .into_iter()
 522                    .map(|(id, _)| ContextServerId(id)),
 523            )
 524            .chain(
 525                self.context_server_settings
 526                    .keys()
 527                    .map(|id| ContextServerId(id.clone())),
 528            )
 529            .unique()
 530            .sorted_unstable_by(
 531                // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
 532                |a, b| {
 533                    const MCP_PREFIX: &str = "mcp-server-";
 534                    match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
 535                        // If one has mcp-server- prefix and other doesn't, non-mcp comes first
 536                        (Some(_), None) => std::cmp::Ordering::Greater,
 537                        (None, Some(_)) => std::cmp::Ordering::Less,
 538                        // If both have same prefix status, sort by appropriate key
 539                        (Some(a), Some(b)) => a.cmp(b),
 540                        (None, None) => a.0.cmp(&b.0),
 541                    }
 542                },
 543            )
 544            .collect();
 545    }
 546
 547    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 548        self.servers
 549            .values()
 550            .filter_map(|state| {
 551                if let ContextServerState::Running { server, .. } = state {
 552                    Some(server.clone())
 553                } else {
 554                    None
 555                }
 556            })
 557            .collect()
 558    }
 559
 560    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 561        cx.spawn(async move |this, cx| {
 562            let this = this.upgrade().context("Context server store dropped")?;
 563            let id = server.id();
 564            let settings = this
 565                .update(cx, |this, _| {
 566                    this.context_server_settings.get(&id.0).cloned()
 567                })
 568                .context("Failed to get context server settings")?;
 569
 570            if !settings.enabled() {
 571                return anyhow::Ok(());
 572            }
 573
 574            let (registry, worktree_store) = this.update(cx, |this, _| {
 575                (this.registry.clone(), this.worktree_store.clone())
 576            });
 577            let configuration = ContextServerConfiguration::from_settings(
 578                settings,
 579                id.clone(),
 580                registry,
 581                worktree_store,
 582                cx,
 583            )
 584            .await
 585            .context("Failed to create context server configuration")?;
 586
 587            this.update(cx, |this, cx| {
 588                this.run_server(server, Arc::new(configuration), cx)
 589            });
 590            Ok(())
 591        })
 592        .detach_and_log_err(cx);
 593    }
 594
 595    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 596        if matches!(
 597            self.servers.get(id),
 598            Some(ContextServerState::Stopped { .. })
 599        ) {
 600            return Ok(());
 601        }
 602
 603        let state = self
 604            .servers
 605            .remove(id)
 606            .context("Context server not found")?;
 607
 608        let server = state.server();
 609        let configuration = state.configuration();
 610        let mut result = Ok(());
 611        if let ContextServerState::Running { server, .. } = &state {
 612            result = server.stop();
 613        }
 614        drop(state);
 615
 616        self.update_server_state(
 617            id.clone(),
 618            ContextServerState::Stopped {
 619                configuration,
 620                server,
 621            },
 622            cx,
 623        );
 624
 625        result
 626    }
 627
 628    fn run_server(
 629        &mut self,
 630        server: Arc<ContextServer>,
 631        configuration: Arc<ContextServerConfiguration>,
 632        cx: &mut Context<Self>,
 633    ) {
 634        let id = server.id();
 635        if matches!(
 636            self.servers.get(&id),
 637            Some(
 638                ContextServerState::Starting { .. }
 639                    | ContextServerState::Running { .. }
 640                    | ContextServerState::Authenticating { .. },
 641            )
 642        ) {
 643            self.stop_server(&id, cx).log_err();
 644        }
 645        let task = cx.spawn({
 646            let id = server.id();
 647            let server = server.clone();
 648            let configuration = configuration.clone();
 649
 650            async move |this, cx| {
 651                let new_state = match server.clone().start(cx).await {
 652                    Ok(_) => {
 653                        debug_assert!(server.client().is_some());
 654                        ContextServerState::Running {
 655                            server,
 656                            configuration,
 657                        }
 658                    }
 659                    Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
 660                };
 661                this.update(cx, |this, cx| {
 662                    this.update_server_state(id.clone(), new_state, cx)
 663                })
 664                .log_err();
 665            }
 666        });
 667
 668        self.update_server_state(
 669            id.clone(),
 670            ContextServerState::Starting {
 671                configuration,
 672                _task: task,
 673                server,
 674            },
 675            cx,
 676        );
 677    }
 678
 679    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 680        let state = self
 681            .servers
 682            .remove(id)
 683            .context("Context server not found")?;
 684
 685        if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
 686            let server_url = url.clone();
 687            let id = id.clone();
 688            cx.spawn(async move |_this, cx| {
 689                let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 690                if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
 691                {
 692                    log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
 693                }
 694            })
 695            .detach();
 696        }
 697
 698        drop(state);
 699        cx.emit(ServerStatusChangedEvent {
 700            server_id: id.clone(),
 701            status: ContextServerStatus::Stopped,
 702        });
 703        Ok(())
 704    }
 705
 706    pub async fn create_context_server(
 707        this: WeakEntity<Self>,
 708        id: ContextServerId,
 709        configuration: Arc<ContextServerConfiguration>,
 710        cx: &mut AsyncApp,
 711    ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
 712        let remote = configuration.remote();
 713        let needs_remote_command = match configuration.as_ref() {
 714            ContextServerConfiguration::Custom { .. }
 715            | ContextServerConfiguration::Extension { .. } => remote,
 716            ContextServerConfiguration::Http { .. } => false,
 717        };
 718
 719        let (remote_state, is_remote_project) = this.update(cx, |this, _| {
 720            let remote_state = match &this.state {
 721                ContextServerStoreState::Remote {
 722                    project_id,
 723                    upstream_client,
 724                } if needs_remote_command => Some((*project_id, upstream_client.clone())),
 725                _ => None,
 726            };
 727            (remote_state, this.is_remote_project())
 728        })?;
 729
 730        let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
 731            this.project
 732                .as_ref()
 733                .and_then(|project| {
 734                    project
 735                        .read_with(cx, |project, cx| project.active_project_directory(cx))
 736                        .ok()
 737                        .flatten()
 738                })
 739                .or_else(|| {
 740                    this.worktree_store.read_with(cx, |store, cx| {
 741                        store.visible_worktrees(cx).fold(None, |acc, item| {
 742                            if acc.is_none() {
 743                                item.read(cx).root_dir()
 744                            } else {
 745                                acc
 746                            }
 747                        })
 748                    })
 749                })
 750        })?;
 751
 752        let configuration = if let Some((project_id, upstream_client)) = remote_state {
 753            let root_dir = root_path.as_ref().map(|p| p.display().to_string());
 754
 755            let response = upstream_client
 756                .update(cx, |client, _| {
 757                    client
 758                        .proto_client()
 759                        .request(proto::GetContextServerCommand {
 760                            project_id,
 761                            server_id: id.0.to_string(),
 762                            root_dir: root_dir.clone(),
 763                        })
 764                })
 765                .await?;
 766
 767            let remote_command = upstream_client.update(cx, |client, _| {
 768                client.build_command(
 769                    Some(response.path),
 770                    &response.args,
 771                    &response.env.into_iter().collect(),
 772                    root_dir,
 773                    None,
 774                )
 775            })?;
 776
 777            let command = ContextServerCommand {
 778                path: remote_command.program.into(),
 779                args: remote_command.args,
 780                env: Some(remote_command.env.into_iter().collect()),
 781                timeout: None,
 782            };
 783
 784            Arc::new(ContextServerConfiguration::Custom { command, remote })
 785        } else {
 786            configuration
 787        };
 788
 789        if let Some(server) = this.update(cx, |this, _| {
 790            this.context_server_factory
 791                .as_ref()
 792                .map(|factory| factory(id.clone(), configuration.clone()))
 793        })? {
 794            return Ok((server, configuration));
 795        }
 796
 797        let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
 798            if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
 799                if configuration.has_static_auth_header() {
 800                    None
 801                } else {
 802                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 803                    let http_client = cx.update(|cx| cx.http_client());
 804
 805                    match Self::load_session(&credentials_provider, url, &cx).await {
 806                        Ok(Some(session)) => {
 807                            log::info!("{} loaded cached OAuth session from keychain", id);
 808                            Some(Self::create_oauth_token_provider(
 809                                &id,
 810                                url,
 811                                session,
 812                                http_client,
 813                                credentials_provider,
 814                                cx,
 815                            ))
 816                        }
 817                        Ok(None) => None,
 818                        Err(err) => {
 819                            log::warn!("{} failed to load cached OAuth session: {}", id, err);
 820                            None
 821                        }
 822                    }
 823                }
 824            } else {
 825                None
 826            };
 827
 828        let server: Arc<ContextServer> = this.update(cx, |this, cx| {
 829            let global_timeout =
 830                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
 831
 832            match configuration.as_ref() {
 833                ContextServerConfiguration::Http {
 834                    url,
 835                    headers,
 836                    timeout,
 837                } => {
 838                    let transport = HttpTransport::new_with_token_provider(
 839                        cx.http_client(),
 840                        url.to_string(),
 841                        headers.clone(),
 842                        cx.background_executor().clone(),
 843                        cached_token_provider.clone(),
 844                    );
 845                    anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
 846                        id,
 847                        Arc::new(transport),
 848                        Some(Duration::from_secs(
 849                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
 850                        )),
 851                    )))
 852                }
 853                _ => {
 854                    let mut command = configuration
 855                        .command()
 856                        .context("Missing command configuration for stdio context server")?
 857                        .clone();
 858                    command.timeout = Some(
 859                        command
 860                            .timeout
 861                            .unwrap_or(global_timeout)
 862                            .min(MAX_TIMEOUT_SECS),
 863                    );
 864
 865                    // Don't pass remote paths as working directory for locally-spawned processes
 866                    let working_directory = if is_remote_project { None } else { root_path };
 867                    anyhow::Ok(Arc::new(ContextServer::stdio(
 868                        id,
 869                        command,
 870                        working_directory,
 871                    )))
 872                }
 873            }
 874        })??;
 875
 876        Ok((server, configuration))
 877    }
 878
 879    async fn handle_get_context_server_command(
 880        this: Entity<Self>,
 881        envelope: TypedEnvelope<proto::GetContextServerCommand>,
 882        mut cx: AsyncApp,
 883    ) -> Result<proto::ContextServerCommand> {
 884        let server_id = ContextServerId(envelope.payload.server_id.into());
 885
 886        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
 887            let ContextServerStoreState::Local {
 888                is_headless: true, ..
 889            } = &this.state
 890            else {
 891                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
 892            };
 893
 894            let settings = this
 895                .context_server_settings
 896                .get(&server_id.0)
 897                .cloned()
 898                .or_else(|| {
 899                    this.registry
 900                        .read(inner_cx)
 901                        .context_server_descriptor(&server_id.0)
 902                        .map(|_| ContextServerSettings::default_extension())
 903                })
 904                .with_context(|| format!("context server `{}` not found", server_id))?;
 905
 906            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
 907        })?;
 908
 909        let configuration = ContextServerConfiguration::from_settings(
 910            settings,
 911            server_id.clone(),
 912            registry,
 913            worktree_store,
 914            &cx,
 915        )
 916        .await
 917        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
 918
 919        let command = configuration
 920            .command()
 921            .context("context server has no command (HTTP servers don't need RPC)")?;
 922
 923        Ok(proto::ContextServerCommand {
 924            path: command.path.display().to_string(),
 925            args: command.args.clone(),
 926            env: command
 927                .env
 928                .clone()
 929                .map(|env| env.into_iter().collect())
 930                .unwrap_or_default(),
 931        })
 932    }
 933
 934    fn resolve_project_settings<'a>(
 935        worktree_store: &'a Entity<WorktreeStore>,
 936        cx: &'a App,
 937    ) -> &'a ProjectSettings {
 938        let location = worktree_store
 939            .read(cx)
 940            .visible_worktrees(cx)
 941            .next()
 942            .map(|worktree| settings::SettingsLocation {
 943                worktree_id: worktree.read(cx).id(),
 944                path: RelPath::empty(),
 945            });
 946        ProjectSettings::get(location, cx)
 947    }
 948
 949    fn create_oauth_token_provider(
 950        id: &ContextServerId,
 951        server_url: &url::Url,
 952        session: OAuthSession,
 953        http_client: Arc<dyn HttpClient>,
 954        credentials_provider: Arc<dyn CredentialsProvider>,
 955        cx: &mut AsyncApp,
 956    ) -> Arc<dyn oauth::OAuthTokenProvider> {
 957        let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
 958        let id = id.clone();
 959        let server_url = server_url.clone();
 960
 961        cx.spawn(async move |cx| {
 962            while let Some(refreshed_session) = token_refresh_rx.next().await {
 963                if let Err(err) =
 964                    Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
 965                        .await
 966                {
 967                    log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
 968                }
 969            }
 970            log::debug!("{} OAuth session persistence task ended", id);
 971        })
 972        .detach();
 973
 974        Arc::new(McpOAuthTokenProvider::new(
 975            session,
 976            http_client,
 977            Some(token_refresh_tx),
 978        ))
 979    }
 980
 981    /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
 982    ///
 983    /// This starts a loopback HTTP callback server on an ephemeral port, builds
 984    /// the authorization URL, opens the user's browser, waits for the callback,
 985    /// exchanges the code for tokens, persists them in the keychain, and restarts
 986    /// the server with the new token provider.
 987    pub fn authenticate_server(
 988        &mut self,
 989        id: &ContextServerId,
 990        cx: &mut Context<Self>,
 991    ) -> Result<()> {
 992        let state = self.servers.get(id).context("Context server not found")?;
 993
 994        let (discovery, server, configuration) = match state {
 995            ContextServerState::AuthRequired {
 996                discovery,
 997                server,
 998                configuration,
 999            } => (discovery.clone(), server.clone(), configuration.clone()),
1000            _ => anyhow::bail!("Server is not in AuthRequired state"),
1001        };
1002
1003        let id = id.clone();
1004
1005        let task = cx.spawn({
1006            let id = id.clone();
1007            let server = server.clone();
1008            let configuration = configuration.clone();
1009            async move |this, cx| {
1010                let result = Self::run_oauth_flow(
1011                    this.clone(),
1012                    id.clone(),
1013                    discovery.clone(),
1014                    configuration.clone(),
1015                    cx,
1016                )
1017                .await;
1018
1019                if let Err(err) = &result {
1020                    log::error!("{} OAuth authentication failed: {:?}", id, err);
1021                    // Transition back to AuthRequired so the user can retry
1022                    // rather than landing in a terminal Error state.
1023                    this.update(cx, |this, cx| {
1024                        this.update_server_state(
1025                            id.clone(),
1026                            ContextServerState::AuthRequired {
1027                                server,
1028                                configuration,
1029                                discovery,
1030                            },
1031                            cx,
1032                        )
1033                    })
1034                    .log_err();
1035                }
1036            }
1037        });
1038
1039        self.update_server_state(
1040            id,
1041            ContextServerState::Authenticating {
1042                server,
1043                configuration,
1044                _task: task,
1045            },
1046            cx,
1047        );
1048
1049        Ok(())
1050    }
1051
1052    async fn run_oauth_flow(
1053        this: WeakEntity<Self>,
1054        id: ContextServerId,
1055        discovery: Arc<OAuthDiscovery>,
1056        configuration: Arc<ContextServerConfiguration>,
1057        cx: &mut AsyncApp,
1058    ) -> Result<()> {
1059        let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1060        let pkce = oauth::generate_pkce_challenge();
1061
1062        let mut state_bytes = [0u8; 32];
1063        rand::rng().fill(&mut state_bytes);
1064        let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1065
1066        // Start a loopback HTTP server on an ephemeral port. The redirect URI
1067        // includes this port so the browser sends the callback directly to our
1068        // process.
1069        let (redirect_uri, callback_rx) = oauth::start_callback_server()
1070            .await
1071            .context("Failed to start OAuth callback server")?;
1072
1073        let http_client = cx.update(|cx| cx.http_client());
1074        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1075        let server_url = match configuration.as_ref() {
1076            ContextServerConfiguration::Http { url, .. } => url.clone(),
1077            _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1078        };
1079
1080        let client_registration =
1081            oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1082                .await
1083                .context("Failed to resolve OAuth client registration")?;
1084
1085        let auth_url = oauth::build_authorization_url(
1086            &discovery.auth_server_metadata,
1087            &client_registration.client_id,
1088            &redirect_uri,
1089            &discovery.scopes,
1090            &resource,
1091            &pkce,
1092            &state_param,
1093        );
1094
1095        cx.update(|cx| cx.open_url(auth_url.as_str()));
1096
1097        let callback = callback_rx
1098            .await
1099            .map_err(|_| {
1100                anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1101            })?
1102            .context("OAuth callback server received an invalid request")?;
1103
1104        if callback.state != state_param {
1105            anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1106        }
1107
1108        let tokens = oauth::exchange_code(
1109            &http_client,
1110            &discovery.auth_server_metadata,
1111            &callback.code,
1112            &client_registration.client_id,
1113            &redirect_uri,
1114            &pkce.verifier,
1115            &resource,
1116        )
1117        .await
1118        .context("Failed to exchange authorization code for tokens")?;
1119
1120        let session = OAuthSession {
1121            token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1122            resource: discovery.resource_metadata.resource.clone(),
1123            client_registration,
1124            tokens,
1125        };
1126
1127        Self::store_session(&credentials_provider, &server_url, &session, cx)
1128            .await
1129            .context("Failed to persist OAuth session in keychain")?;
1130
1131        let token_provider = Self::create_oauth_token_provider(
1132            &id,
1133            &server_url,
1134            session,
1135            http_client.clone(),
1136            credentials_provider,
1137            cx,
1138        );
1139
1140        let new_server = this.update(cx, |this, cx| {
1141            let global_timeout =
1142                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1143
1144            match configuration.as_ref() {
1145                ContextServerConfiguration::Http {
1146                    url,
1147                    headers,
1148                    timeout,
1149                } => {
1150                    let transport = HttpTransport::new_with_token_provider(
1151                        http_client.clone(),
1152                        url.to_string(),
1153                        headers.clone(),
1154                        cx.background_executor().clone(),
1155                        Some(token_provider.clone()),
1156                    );
1157                    Ok(Arc::new(ContextServer::new_with_timeout(
1158                        id.clone(),
1159                        Arc::new(transport),
1160                        Some(Duration::from_secs(
1161                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1162                        )),
1163                    )))
1164                }
1165                _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1166            }
1167        })??;
1168
1169        this.update(cx, |this, cx| {
1170            this.run_server(new_server, configuration, cx);
1171        })?;
1172
1173        Ok(())
1174    }
1175
1176    /// Store the full OAuth session in the system keychain, keyed by the
1177    /// server's canonical URI.
1178    async fn store_session(
1179        credentials_provider: &Arc<dyn CredentialsProvider>,
1180        server_url: &url::Url,
1181        session: &OAuthSession,
1182        cx: &AsyncApp,
1183    ) -> Result<()> {
1184        let key = Self::keychain_key(server_url);
1185        let json = serde_json::to_string(session)?;
1186        credentials_provider
1187            .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1188            .await
1189    }
1190
1191    /// Load the full OAuth session from the system keychain for the given
1192    /// server URL.
1193    async fn load_session(
1194        credentials_provider: &Arc<dyn CredentialsProvider>,
1195        server_url: &url::Url,
1196        cx: &AsyncApp,
1197    ) -> Result<Option<OAuthSession>> {
1198        let key = Self::keychain_key(server_url);
1199        match credentials_provider.read_credentials(&key, cx).await? {
1200            Some((_username, password_bytes)) => {
1201                let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1202                Ok(Some(session))
1203            }
1204            None => Ok(None),
1205        }
1206    }
1207
1208    /// Clear the stored OAuth session from the system keychain.
1209    async fn clear_session(
1210        credentials_provider: &Arc<dyn CredentialsProvider>,
1211        server_url: &url::Url,
1212        cx: &AsyncApp,
1213    ) -> Result<()> {
1214        let key = Self::keychain_key(server_url);
1215        credentials_provider.delete_credentials(&key, cx).await
1216    }
1217
1218    fn keychain_key(server_url: &url::Url) -> String {
1219        format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1220    }
1221
1222    /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1223    /// session from the keychain and stop the server.
1224    pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1225        let state = self.servers.get(id).context("Context server not found")?;
1226        let configuration = state.configuration();
1227
1228        let server_url = match configuration.as_ref() {
1229            ContextServerConfiguration::Http { url, .. } => url.clone(),
1230            _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1231        };
1232
1233        let id = id.clone();
1234        self.stop_server(&id, cx)?;
1235
1236        cx.spawn(async move |this, cx| {
1237            let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1238            if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1239                log::error!("{} failed to clear OAuth session: {}", id, err);
1240            }
1241            // Trigger server recreation so the next start uses a fresh
1242            // transport without the old (now-invalidated) token provider.
1243            this.update(cx, |this, cx| {
1244                this.available_context_servers_changed(cx);
1245            })
1246            .log_err();
1247        })
1248        .detach();
1249
1250        Ok(())
1251    }
1252
1253    fn update_server_state(
1254        &mut self,
1255        id: ContextServerId,
1256        state: ContextServerState,
1257        cx: &mut Context<Self>,
1258    ) {
1259        let status = ContextServerStatus::from_state(&state);
1260        self.servers.insert(id.clone(), state);
1261        cx.emit(ServerStatusChangedEvent {
1262            server_id: id,
1263            status,
1264        });
1265    }
1266
1267    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1268        if self.update_servers_task.is_some() {
1269            self.needs_server_update = true;
1270        } else {
1271            self.needs_server_update = false;
1272            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1273                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1274                    log::error!("Error maintaining context servers: {}", err);
1275                }
1276
1277                this.update(cx, |this, cx| {
1278                    this.populate_server_ids(cx);
1279                    cx.notify();
1280                    this.update_servers_task.take();
1281                    if this.needs_server_update {
1282                        this.available_context_servers_changed(cx);
1283                    }
1284                })?;
1285
1286                Ok(())
1287            }));
1288        }
1289    }
1290
    /// Reconcile the store's servers with the current settings and registry.
    ///
    /// - If AI is disabled, every known server is stopped (errors ignored) and
    ///   nothing is started.
    /// - Otherwise, every registry descriptor without an explicit settings
    ///   entry gets `ContextServerSettings::default_extension()`. Settings are
    ///   split into enabled and disabled, and configurations are built
    ///   concurrently for the enabled ones. Ids whose configuration could not
    ///   be built are dropped.
    /// - Known servers without a configuration are stopped if disabled, and
    ///   removed otherwise.
    /// - Configured servers whose configuration changed, or that are
    ///   `Stopped`, are (re)started. Known ones are stopped first.
    /// - Creation happens sequentially. A creation failure is logged and
    ///   reported as an `Error` status event.
    ///
    /// # Errors
    /// Fails if the store entity has been dropped, or if stopping/removing a
    /// server fails.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        // Don't start context servers if AI is disabled
        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
        if ai_disabled {
            // Stop all running servers when AI is disabled
            this.update(cx, |this, cx| {
                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
                for id in server_ids {
                    // Best effort: a failure to stop one server must not block the rest.
                    let _ = this.stop_server(&id, cx);
                }
            })?;
            return Ok(());
        }

        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
            (
                this.context_server_settings.clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Registry-provided servers (e.g. from extensions) are configured with
        // default extension settings unless the user configured them explicitly.
        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::default_extension());
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Build configurations for all enabled servers concurrently; servers
        // whose configuration resolves to `None` are skipped.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(move |config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        this.update(cx, |this, _cx| {
            for server_id in this.servers.keys() {
                // Known servers without a usable configuration are stopped if the
                // user disabled them, and removed from the store otherwise (e.g.
                // the server was deleted from the context server settings).
                if !configured_servers.contains_key(server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            // (Re)start servers that are new, whose configuration changed, or
            // that are currently stopped. Servers in any other state with an
            // unchanged configuration are left alone.
            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    servers_to_start.push((id.clone(), config));
                    // An existing instance must be stopped before its replacement starts.
                    if this.servers.contains_key(&id) {
                        servers_to_stop.insert(id);
                    }
                }
            }

            anyhow::Ok(())
        })??;

        this.update(cx, |this, inner_cx| {
            for id in servers_to_stop {
                this.stop_server(&id, inner_cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, inner_cx)?;
            }
            anyhow::Ok(())
        })??;

        for (id, config) in servers_to_start {
            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
                Ok((server, config)) => {
                    this.update(cx, |this, cx| {
                        this.run_server(server, config, cx);
                    })?;
                }
                Err(err) => {
                    log::error!("{id} context server failed to create: {err:#}");
                    // NOTE(review): only an event is emitted here; `this.servers`
                    // is not updated with an Error state for this id.
                    this.update(cx, |_this, cx| {
                        cx.emit(ServerStatusChangedEvent {
                            server_id: id,
                            status: ContextServerStatus::Error(err.to_string().into()),
                        });
                        cx.notify();
                    })?;
                }
            }
        }

        Ok(())
    }
1405}
1406
1407/// Determines the appropriate server state after a start attempt fails.
1408///
1409/// When the error is an HTTP 401 with no static auth header configured,
1410/// attempts OAuth discovery so the UI can offer an authentication flow.
1411async fn resolve_start_failure(
1412    id: &ContextServerId,
1413    err: anyhow::Error,
1414    server: Arc<ContextServer>,
1415    configuration: Arc<ContextServerConfiguration>,
1416    cx: &AsyncApp,
1417) -> ContextServerState {
1418    let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1419        TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1420    });
1421
1422    if www_authenticate.is_some() && configuration.has_static_auth_header() {
1423        log::warn!("{id} received 401 with a static Authorization header configured");
1424        return ContextServerState::Error {
1425            configuration,
1426            server,
1427            error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1428                .into(),
1429        };
1430    }
1431
1432    let server_url = match configuration.as_ref() {
1433        ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1434            url.clone()
1435        }
1436        _ => {
1437            if www_authenticate.is_some() {
1438                log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1439            } else {
1440                log::error!("{id} context server failed to start: {err}");
1441            }
1442            return ContextServerState::Error {
1443                configuration,
1444                server,
1445                error: err.to_string().into(),
1446            };
1447        }
1448    };
1449
1450    // When the error is NOT a 401 but there is a cached OAuth session in the
1451    // keychain, the session is likely stale/expired and caused the failure
1452    // (e.g. timeout because the server rejected the token silently). Clear it
1453    // so the next start attempt can get a clean 401 and trigger the auth flow.
1454    if www_authenticate.is_none() {
1455        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1456        match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1457            Ok(Some(_)) => {
1458                log::info!("{id} start failed with a cached OAuth session present; clearing it");
1459                ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1460                    .await
1461                    .log_err();
1462            }
1463            _ => {
1464                log::error!("{id} context server failed to start: {err}");
1465                return ContextServerState::Error {
1466                    configuration,
1467                    server,
1468                    error: err.to_string().into(),
1469                };
1470            }
1471        }
1472    }
1473
1474    let default_www_authenticate = oauth::WwwAuthenticate {
1475        resource_metadata: None,
1476        scope: None,
1477        error: None,
1478        error_description: None,
1479    };
1480    let www_authenticate = www_authenticate
1481        .as_ref()
1482        .unwrap_or(&default_www_authenticate);
1483    let http_client = cx.update(|cx| cx.http_client());
1484
1485    match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1486        Ok(discovery) => {
1487            log::info!(
1488                "{id} requires OAuth authorization (auth server: {})",
1489                discovery.auth_server_metadata.issuer,
1490            );
1491            ContextServerState::AuthRequired {
1492                server,
1493                configuration,
1494                discovery: Arc::new(discovery),
1495            }
1496        }
1497        Err(discovery_err) => {
1498            log::error!("{id} OAuth discovery failed: {discovery_err}");
1499            ContextServerState::Error {
1500                configuration,
1501                server,
1502                error: format!("OAuth discovery failed: {discovery_err}").into(),
1503            }
1504        }
1505    }
1506}