context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::path::Path;
   5use std::sync::Arc;
   6use std::time::Duration;
   7
   8use anyhow::{Context as _, Result};
   9use collections::{HashMap, HashSet};
  10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
  11use context_server::transport::{HttpTransport, TransportError};
  12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  13use credentials_provider::CredentialsProvider;
  14use futures::future::Either;
  15use futures::{FutureExt as _, StreamExt as _, future::join_all};
  16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  17use http_client::HttpClient;
  18use itertools::Itertools;
  19use rand::Rng as _;
  20use registry::ContextServerDescriptorRegistry;
  21use remote::RemoteClient;
  22use rpc::{AnyProtoClient, TypedEnvelope, proto};
  23use settings::{Settings as _, SettingsStore};
  24use util::{ResultExt as _, rel_path::RelPath};
  25
  26use crate::{
  27    DisableAiSettings, Project,
  28    project_settings::{ContextServerSettings, ProjectSettings},
  29    worktree_store::WorktreeStore,
  30};
  31
  32/// Maximum timeout for context server requests
  33/// Prevents extremely large timeout values from tying up resources indefinitely.
  34const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
  35
/// Initializes this module's global state by delegating to `extension::init`.
pub fn init(cx: &mut App) {
    extension::init(cx);
}
  39
// Actions registered under the `context_server` namespace.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  47
/// Public, cloneable status of a context server.
///
/// Derived from the store's internal per-server state (see `from_state`) and
/// carried by `ServerStatusChangedEvent`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// A start attempt has been spawned and has not finished yet.
    Starting,
    /// The server's `start` call completed successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// The server is in an error state; the payload is the error message.
    Error(Arc<str>),
    /// The server returned 401 and OAuth authorization is needed. The UI
    /// should show an "Authenticate" button.
    AuthRequired,
    /// The OAuth browser flow is in progress — the user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating,
}
  61
  62impl ContextServerStatus {
  63    fn from_state(state: &ContextServerState) -> Self {
  64        match state {
  65            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  66            ContextServerState::Running { .. } => ContextServerStatus::Running,
  67            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  68            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  69            ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
  70            ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
  71        }
  72    }
  73}
  74
/// Internal lifecycle state of a single context server.
///
/// Every variant keeps the server handle and the configuration it was built
/// from, so `server()` and `configuration()` work in any stage.
enum ContextServerState {
    /// `run_server` has spawned a start task that has not completed yet.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // The spawned start task. It is owned by this state, so it is dropped
        // when the state is replaced or removed.
        _task: Task<()>,
    },
    /// The start task completed successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped via `stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is in an error state; `error` is surfaced as
    /// `ContextServerStatus::Error`.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
    /// The server requires OAuth authorization before it can be used. The
    /// `OAuthDiscovery` holds everything needed to start the browser flow.
    AuthRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// The OAuth browser flow is in progress. The user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        // Task driving the OAuth flow; dropped together with this state.
        _task: Task<()>,
    },
}
 109
 110impl ContextServerState {
 111    pub fn server(&self) -> Arc<ContextServer> {
 112        match self {
 113            ContextServerState::Starting { server, .. }
 114            | ContextServerState::Running { server, .. }
 115            | ContextServerState::Stopped { server, .. }
 116            | ContextServerState::Error { server, .. }
 117            | ContextServerState::AuthRequired { server, .. }
 118            | ContextServerState::Authenticating { server, .. } => server.clone(),
 119        }
 120    }
 121
 122    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
 123        match self {
 124            ContextServerState::Starting { configuration, .. }
 125            | ContextServerState::Running { configuration, .. }
 126            | ContextServerState::Stopped { configuration, .. }
 127            | ContextServerState::Error { configuration, .. }
 128            | ContextServerState::AuthRequired { configuration, .. }
 129            | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
 130        }
 131    }
 132}
 133
/// Fully resolved description of how to reach a context server.
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A user-configured stdio command.
    Custom {
        command: ContextServerCommand,
        /// When set in a remote project, the command is resolved on the
        /// remote host before being spawned.
        remote: bool,
    },
    /// A stdio command supplied by an extension's server descriptor.
    Extension {
        command: ContextServerCommand,
        /// The extension server's settings as written in the user's config.
        settings: serde_json::Value,
        /// Same meaning as `Custom::remote`.
        remote: bool,
    },
    /// A server reached over HTTP.
    Http {
        url: url::Url,
        /// Extra headers sent to the server. An `Authorization` header here
        /// (case-insensitive) disables the cached-OAuth lookup.
        headers: HashMap<String, String>,
        /// Per-server request timeout in seconds; falls back to the
        /// project-wide timeout when `None`.
        timeout: Option<u64>,
    },
}
 151
 152impl ContextServerConfiguration {
 153    pub fn command(&self) -> Option<&ContextServerCommand> {
 154        match self {
 155            ContextServerConfiguration::Custom { command, .. } => Some(command),
 156            ContextServerConfiguration::Extension { command, .. } => Some(command),
 157            ContextServerConfiguration::Http { .. } => None,
 158        }
 159    }
 160
 161    pub fn has_static_auth_header(&self) -> bool {
 162        match self {
 163            ContextServerConfiguration::Http { headers, .. } => headers
 164                .keys()
 165                .any(|k| k.eq_ignore_ascii_case("authorization")),
 166            _ => false,
 167        }
 168    }
 169
 170    pub fn remote(&self) -> bool {
 171        match self {
 172            ContextServerConfiguration::Custom { remote, .. } => *remote,
 173            ContextServerConfiguration::Extension { remote, .. } => *remote,
 174            ContextServerConfiguration::Http { .. } => false,
 175        }
 176    }
 177
 178    pub async fn from_settings(
 179        settings: ContextServerSettings,
 180        id: ContextServerId,
 181        registry: Entity<ContextServerDescriptorRegistry>,
 182        worktree_store: Entity<WorktreeStore>,
 183        cx: &AsyncApp,
 184    ) -> Option<Self> {
 185        const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
 186
 187        match settings {
 188            ContextServerSettings::Stdio {
 189                enabled: _,
 190                command,
 191                remote,
 192            } => Some(ContextServerConfiguration::Custom { command, remote }),
 193            ContextServerSettings::Extension {
 194                enabled: _,
 195                settings,
 196                remote,
 197            } => {
 198                let descriptor =
 199                    cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
 200
 201                let command_future = descriptor.command(worktree_store, cx);
 202                let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
 203
 204                match futures::future::select(command_future, timeout_future).await {
 205                    Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
 206                        command,
 207                        settings,
 208                        remote,
 209                    }),
 210                    Either::Left((Err(e), _)) => {
 211                        log::error!(
 212                            "Failed to create context server configuration from settings: {e:#}"
 213                        );
 214                        None
 215                    }
 216                    Either::Right(_) => {
 217                        log::error!(
 218                            "Timed out resolving command for extension context server {id}"
 219                        );
 220                        None
 221                    }
 222                }
 223            }
 224            ContextServerSettings::Http {
 225                enabled: _,
 226                url,
 227                headers: auth,
 228                timeout,
 229            } => {
 230                let url = url::Url::parse(&url).log_err()?;
 231                Some(ContextServerConfiguration::Http {
 232                    url,
 233                    headers: auth,
 234                    timeout,
 235                })
 236            }
 237        }
 238    }
 239}
 240
/// Overrides how `ContextServer` instances are built for a given id and
/// configuration. Installed through the test-support helpers
/// (`set_context_server_factory`, `test_maintain_server_loop`).
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 243
/// Whether the store belongs to a local project or a remote one.
enum ContextServerStoreState {
    Local {
        /// Set by `shared` to the `(project_id, client)` this project is
        /// shared with.
        downstream_client: Option<(u64, AnyProtoClient)>,
        /// `true` when built via `ContextServerStore::local(.., headless: true, ..)`.
        is_headless: bool,
    },
    Remote {
        project_id: u64,
        /// Used to resolve the commands of `remote` servers on the remote host.
        upstream_client: Entity<RemoteClient>,
    },
}
 254
/// Owns the lifecycle of every context (MCP) server of a project.
///
/// Tracks per-server state, keeps server settings in sync with the settings
/// store, and emits `ServerStatusChangedEvent`s.
pub struct ContextServerStore {
    /// Local (optionally shared downstream) or remote-project mode.
    state: ContextServerStoreState,
    /// Last resolved `context_servers` project settings, keyed by server id.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current lifecycle state of every server the store tracks.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Sorted, deduplicated ids computed by `populate_server_ids`.
    server_ids: Vec<ContextServerId>,
    /// Used to resolve project settings and fall back to a worktree root.
    worktree_store: Entity<WorktreeStore>,
    /// Used to find the active project directory for a server's working dir.
    project: Option<WeakEntity<Project>>,
    /// Source of extension-provided server descriptors.
    registry: Entity<ContextServerDescriptorRegistry>,
    // NOTE(review): presumably the in-flight refresh started by
    // `available_context_servers_changed` (not visible here) — confirm.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override for server construction (see `ContextServerFactory`).
    context_server_factory: Option<ContextServerFactory>,
    // NOTE(review): initialized to `false`; presumably marks a refresh
    // requested while `update_servers_task` is running — confirm.
    needs_server_update: bool,
    /// Last observed `DisableAiSettings::disable_ai`; used to detect AI being
    /// re-enabled.
    ai_disabled: bool,
    /// Keeps the settings (and, when maintaining servers, registry) observers
    /// alive for as long as the store exists.
    _subscriptions: Vec<Subscription>,
}
 269
/// Emitted by `ContextServerStore` when a server's status changes
/// (including `Stopped` when a server is removed).
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The server's new status.
    pub status: ContextServerStatus,
}

impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
 276
 277impl ContextServerStore {
 278    pub fn local(
 279        worktree_store: Entity<WorktreeStore>,
 280        weak_project: Option<WeakEntity<Project>>,
 281        headless: bool,
 282        cx: &mut Context<Self>,
 283    ) -> Self {
 284        Self::new_internal(
 285            !headless,
 286            None,
 287            ContextServerDescriptorRegistry::default_global(cx),
 288            worktree_store,
 289            weak_project,
 290            ContextServerStoreState::Local {
 291                downstream_client: None,
 292                is_headless: headless,
 293            },
 294            cx,
 295        )
 296    }
 297
 298    pub fn remote(
 299        project_id: u64,
 300        upstream_client: Entity<RemoteClient>,
 301        worktree_store: Entity<WorktreeStore>,
 302        weak_project: Option<WeakEntity<Project>>,
 303        cx: &mut Context<Self>,
 304    ) -> Self {
 305        Self::new_internal(
 306            true,
 307            None,
 308            ContextServerDescriptorRegistry::default_global(cx),
 309            worktree_store,
 310            weak_project,
 311            ContextServerStoreState::Remote {
 312                project_id,
 313                upstream_client,
 314            },
 315            cx,
 316        )
 317    }
 318
 319    pub fn init_headless(session: &AnyProtoClient) {
 320        session.add_entity_request_handler(Self::handle_get_context_server_command);
 321    }
 322
 323    pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
 324        if let ContextServerStoreState::Local {
 325            downstream_client, ..
 326        } = &mut self.state
 327        {
 328            *downstream_client = Some((project_id, client));
 329        }
 330    }
 331
 332    pub fn is_remote_project(&self) -> bool {
 333        matches!(self.state, ContextServerStoreState::Remote { .. })
 334    }
 335
 336    /// Returns all configured context server ids, excluding the ones that are disabled
 337    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 338        self.context_server_settings
 339            .iter()
 340            .filter(|(_, settings)| settings.enabled())
 341            .map(|(id, _)| ContextServerId(id.clone()))
 342            .collect()
 343    }
 344
 345    #[cfg(feature = "test-support")]
 346    pub fn test(
 347        registry: Entity<ContextServerDescriptorRegistry>,
 348        worktree_store: Entity<WorktreeStore>,
 349        weak_project: Option<WeakEntity<Project>>,
 350        cx: &mut Context<Self>,
 351    ) -> Self {
 352        Self::new_internal(
 353            false,
 354            None,
 355            registry,
 356            worktree_store,
 357            weak_project,
 358            ContextServerStoreState::Local {
 359                downstream_client: None,
 360                is_headless: false,
 361            },
 362            cx,
 363        )
 364    }
 365
 366    #[cfg(feature = "test-support")]
 367    pub fn test_maintain_server_loop(
 368        context_server_factory: Option<ContextServerFactory>,
 369        registry: Entity<ContextServerDescriptorRegistry>,
 370        worktree_store: Entity<WorktreeStore>,
 371        weak_project: Option<WeakEntity<Project>>,
 372        cx: &mut Context<Self>,
 373    ) -> Self {
 374        Self::new_internal(
 375            true,
 376            context_server_factory,
 377            registry,
 378            worktree_store,
 379            weak_project,
 380            ContextServerStoreState::Local {
 381                downstream_client: None,
 382                is_headless: false,
 383            },
 384            cx,
 385        )
 386    }
 387
 388    #[cfg(feature = "test-support")]
 389    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
 390        self.context_server_factory = Some(factory);
 391    }
 392
 393    #[cfg(feature = "test-support")]
 394    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
 395        &self.registry
 396    }
 397
 398    #[cfg(feature = "test-support")]
 399    pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 400        let configuration = Arc::new(ContextServerConfiguration::Custom {
 401            command: ContextServerCommand {
 402                path: "test".into(),
 403                args: vec![],
 404                env: None,
 405                timeout: None,
 406            },
 407            remote: false,
 408        });
 409        self.run_server(server, configuration, cx);
 410    }
 411
 412    fn new_internal(
 413        maintain_server_loop: bool,
 414        context_server_factory: Option<ContextServerFactory>,
 415        registry: Entity<ContextServerDescriptorRegistry>,
 416        worktree_store: Entity<WorktreeStore>,
 417        weak_project: Option<WeakEntity<Project>>,
 418        state: ContextServerStoreState,
 419        cx: &mut Context<Self>,
 420    ) -> Self {
 421        let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 422            let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 423            let ai_was_disabled = this.ai_disabled;
 424            this.ai_disabled = ai_disabled;
 425
 426            let settings =
 427                &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
 428            let settings_changed = &this.context_server_settings != settings;
 429
 430            if settings_changed {
 431                this.context_server_settings = settings.clone();
 432            }
 433
 434            // When AI is disabled, stop all running servers
 435            if ai_disabled {
 436                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
 437                for id in server_ids {
 438                    this.stop_server(&id, cx).log_err();
 439                }
 440                return;
 441            }
 442
 443            // Trigger updates if AI was re-enabled or settings changed
 444            if maintain_server_loop && (ai_was_disabled || settings_changed) {
 445                this.available_context_servers_changed(cx);
 446            }
 447        })];
 448
 449        if maintain_server_loop {
 450            subscriptions.push(cx.observe(&registry, |this, _registry, cx| {
 451                if !DisableAiSettings::get_global(cx).disable_ai {
 452                    this.available_context_servers_changed(cx);
 453                }
 454            }));
 455        }
 456
 457        let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 458        let mut this = Self {
 459            state,
 460            _subscriptions: subscriptions,
 461            context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
 462                .context_servers
 463                .clone(),
 464            worktree_store,
 465            project: weak_project,
 466            registry,
 467            needs_server_update: false,
 468            ai_disabled,
 469            servers: HashMap::default(),
 470            server_ids: Default::default(),
 471            update_servers_task: None,
 472            context_server_factory,
 473        };
 474        if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
 475            this.available_context_servers_changed(cx);
 476        }
 477        this
 478    }
 479
 480    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 481        self.servers.get(id).map(|state| state.server())
 482    }
 483
 484    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 485        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 486            Some(server.clone())
 487        } else {
 488            None
 489        }
 490    }
 491
 492    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 493        self.servers.get(id).map(ContextServerStatus::from_state)
 494    }
 495
 496    pub fn configuration_for_server(
 497        &self,
 498        id: &ContextServerId,
 499    ) -> Option<Arc<ContextServerConfiguration>> {
 500        self.servers.get(id).map(|state| state.configuration())
 501    }
 502
 503    /// Returns a sorted slice of available unique context server IDs. Within the
 504    /// slice, context servers which have `mcp-server-` as a prefix in their ID will
 505    /// appear after servers that do not have this prefix in their ID.
 506    pub fn server_ids(&self) -> &[ContextServerId] {
 507        self.server_ids.as_slice()
 508    }
 509
 510    fn populate_server_ids(&mut self, cx: &App) {
 511        self.server_ids = self
 512            .servers
 513            .keys()
 514            .cloned()
 515            .chain(
 516                self.registry
 517                    .read(cx)
 518                    .context_server_descriptors()
 519                    .into_iter()
 520                    .map(|(id, _)| ContextServerId(id)),
 521            )
 522            .chain(
 523                self.context_server_settings
 524                    .keys()
 525                    .map(|id| ContextServerId(id.clone())),
 526            )
 527            .unique()
 528            .sorted_unstable_by(
 529                // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
 530                |a, b| {
 531                    const MCP_PREFIX: &str = "mcp-server-";
 532                    match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
 533                        // If one has mcp-server- prefix and other doesn't, non-mcp comes first
 534                        (Some(_), None) => std::cmp::Ordering::Greater,
 535                        (None, Some(_)) => std::cmp::Ordering::Less,
 536                        // If both have same prefix status, sort by appropriate key
 537                        (Some(a), Some(b)) => a.cmp(b),
 538                        (None, None) => a.0.cmp(&b.0),
 539                    }
 540                },
 541            )
 542            .collect();
 543    }
 544
 545    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 546        self.servers
 547            .values()
 548            .filter_map(|state| {
 549                if let ContextServerState::Running { server, .. } = state {
 550                    Some(server.clone())
 551                } else {
 552                    None
 553                }
 554            })
 555            .collect()
 556    }
 557
 558    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 559        cx.spawn(async move |this, cx| {
 560            let this = this.upgrade().context("Context server store dropped")?;
 561            let id = server.id();
 562            let settings = this
 563                .update(cx, |this, _| {
 564                    this.context_server_settings.get(&id.0).cloned()
 565                })
 566                .context("Failed to get context server settings")?;
 567
 568            if !settings.enabled() {
 569                return anyhow::Ok(());
 570            }
 571
 572            let (registry, worktree_store) = this.update(cx, |this, _| {
 573                (this.registry.clone(), this.worktree_store.clone())
 574            });
 575            let configuration = ContextServerConfiguration::from_settings(
 576                settings,
 577                id.clone(),
 578                registry,
 579                worktree_store,
 580                cx,
 581            )
 582            .await
 583            .context("Failed to create context server configuration")?;
 584
 585            this.update(cx, |this, cx| {
 586                this.run_server(server, Arc::new(configuration), cx)
 587            });
 588            Ok(())
 589        })
 590        .detach_and_log_err(cx);
 591    }
 592
 593    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 594        if matches!(
 595            self.servers.get(id),
 596            Some(ContextServerState::Stopped { .. })
 597        ) {
 598            return Ok(());
 599        }
 600
 601        let state = self
 602            .servers
 603            .remove(id)
 604            .context("Context server not found")?;
 605
 606        let server = state.server();
 607        let configuration = state.configuration();
 608        let mut result = Ok(());
 609        if let ContextServerState::Running { server, .. } = &state {
 610            result = server.stop();
 611        }
 612        drop(state);
 613
 614        self.update_server_state(
 615            id.clone(),
 616            ContextServerState::Stopped {
 617                configuration,
 618                server,
 619            },
 620            cx,
 621        );
 622
 623        result
 624    }
 625
 626    fn run_server(
 627        &mut self,
 628        server: Arc<ContextServer>,
 629        configuration: Arc<ContextServerConfiguration>,
 630        cx: &mut Context<Self>,
 631    ) {
 632        let id = server.id();
 633        if matches!(
 634            self.servers.get(&id),
 635            Some(
 636                ContextServerState::Starting { .. }
 637                    | ContextServerState::Running { .. }
 638                    | ContextServerState::Authenticating { .. },
 639            )
 640        ) {
 641            self.stop_server(&id, cx).log_err();
 642        }
 643        let task = cx.spawn({
 644            let id = server.id();
 645            let server = server.clone();
 646            let configuration = configuration.clone();
 647
 648            async move |this, cx| {
 649                let new_state = match server.clone().start(cx).await {
 650                    Ok(_) => {
 651                        debug_assert!(server.client().is_some());
 652                        ContextServerState::Running {
 653                            server,
 654                            configuration,
 655                        }
 656                    }
 657                    Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
 658                };
 659                this.update(cx, |this, cx| {
 660                    this.update_server_state(id.clone(), new_state, cx)
 661                })
 662                .log_err();
 663            }
 664        });
 665
 666        self.update_server_state(
 667            id.clone(),
 668            ContextServerState::Starting {
 669                configuration,
 670                _task: task,
 671                server,
 672            },
 673            cx,
 674        );
 675    }
 676
 677    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 678        let state = self
 679            .servers
 680            .remove(id)
 681            .context("Context server not found")?;
 682
 683        if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
 684            let server_url = url.clone();
 685            let id = id.clone();
 686            cx.spawn(async move |_this, cx| {
 687                let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 688                if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
 689                {
 690                    log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
 691                }
 692            })
 693            .detach();
 694        }
 695
 696        drop(state);
 697        cx.emit(ServerStatusChangedEvent {
 698            server_id: id.clone(),
 699            status: ContextServerStatus::Stopped,
 700        });
 701        Ok(())
 702    }
 703
 704    pub async fn create_context_server(
 705        this: WeakEntity<Self>,
 706        id: ContextServerId,
 707        configuration: Arc<ContextServerConfiguration>,
 708        cx: &mut AsyncApp,
 709    ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
 710        let remote = configuration.remote();
 711        let needs_remote_command = match configuration.as_ref() {
 712            ContextServerConfiguration::Custom { .. }
 713            | ContextServerConfiguration::Extension { .. } => remote,
 714            ContextServerConfiguration::Http { .. } => false,
 715        };
 716
 717        let (remote_state, is_remote_project) = this.update(cx, |this, _| {
 718            let remote_state = match &this.state {
 719                ContextServerStoreState::Remote {
 720                    project_id,
 721                    upstream_client,
 722                } if needs_remote_command => Some((*project_id, upstream_client.clone())),
 723                _ => None,
 724            };
 725            (remote_state, this.is_remote_project())
 726        })?;
 727
 728        let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
 729            this.project
 730                .as_ref()
 731                .and_then(|project| {
 732                    project
 733                        .read_with(cx, |project, cx| project.active_project_directory(cx))
 734                        .ok()
 735                        .flatten()
 736                })
 737                .or_else(|| {
 738                    this.worktree_store.read_with(cx, |store, cx| {
 739                        store.visible_worktrees(cx).fold(None, |acc, item| {
 740                            if acc.is_none() {
 741                                item.read(cx).root_dir()
 742                            } else {
 743                                acc
 744                            }
 745                        })
 746                    })
 747                })
 748        })?;
 749
 750        let configuration = if let Some((project_id, upstream_client)) = remote_state {
 751            let root_dir = root_path.as_ref().map(|p| p.display().to_string());
 752
 753            let response = upstream_client
 754                .update(cx, |client, _| {
 755                    client
 756                        .proto_client()
 757                        .request(proto::GetContextServerCommand {
 758                            project_id,
 759                            server_id: id.0.to_string(),
 760                            root_dir: root_dir.clone(),
 761                        })
 762                })
 763                .await?;
 764
 765            let remote_command = upstream_client.update(cx, |client, _| {
 766                client.build_command(
 767                    Some(response.path),
 768                    &response.args,
 769                    &response.env.into_iter().collect(),
 770                    root_dir,
 771                    None,
 772                )
 773            })?;
 774
 775            let command = ContextServerCommand {
 776                path: remote_command.program.into(),
 777                args: remote_command.args,
 778                env: Some(remote_command.env.into_iter().collect()),
 779                timeout: None,
 780            };
 781
 782            Arc::new(ContextServerConfiguration::Custom { command, remote })
 783        } else {
 784            configuration
 785        };
 786
 787        if let Some(server) = this.update(cx, |this, _| {
 788            this.context_server_factory
 789                .as_ref()
 790                .map(|factory| factory(id.clone(), configuration.clone()))
 791        })? {
 792            return Ok((server, configuration));
 793        }
 794
 795        let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
 796            if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
 797                if configuration.has_static_auth_header() {
 798                    None
 799                } else {
 800                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 801                    let http_client = cx.update(|cx| cx.http_client());
 802
 803                    match Self::load_session(&credentials_provider, url, &cx).await {
 804                        Ok(Some(session)) => {
 805                            log::info!("{} loaded cached OAuth session from keychain", id);
 806                            Some(Self::create_oauth_token_provider(
 807                                &id,
 808                                url,
 809                                session,
 810                                http_client,
 811                                credentials_provider,
 812                                cx,
 813                            ))
 814                        }
 815                        Ok(None) => None,
 816                        Err(err) => {
 817                            log::warn!("{} failed to load cached OAuth session: {}", id, err);
 818                            None
 819                        }
 820                    }
 821                }
 822            } else {
 823                None
 824            };
 825
 826        let server: Arc<ContextServer> = this.update(cx, |this, cx| {
 827            let global_timeout =
 828                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
 829
 830            match configuration.as_ref() {
 831                ContextServerConfiguration::Http {
 832                    url,
 833                    headers,
 834                    timeout,
 835                } => {
 836                    let transport = HttpTransport::new_with_token_provider(
 837                        cx.http_client(),
 838                        url.to_string(),
 839                        headers.clone(),
 840                        cx.background_executor().clone(),
 841                        cached_token_provider.clone(),
 842                    );
 843                    anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
 844                        id,
 845                        Arc::new(transport),
 846                        Some(Duration::from_secs(
 847                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
 848                        )),
 849                    )))
 850                }
 851                _ => {
 852                    let mut command = configuration
 853                        .command()
 854                        .context("Missing command configuration for stdio context server")?
 855                        .clone();
 856                    command.timeout = Some(
 857                        command
 858                            .timeout
 859                            .unwrap_or(global_timeout)
 860                            .min(MAX_TIMEOUT_SECS),
 861                    );
 862
 863                    // Don't pass remote paths as working directory for locally-spawned processes
 864                    let working_directory = if is_remote_project { None } else { root_path };
 865                    anyhow::Ok(Arc::new(ContextServer::stdio(
 866                        id,
 867                        command,
 868                        working_directory,
 869                    )))
 870                }
 871            }
 872        })??;
 873
 874        Ok((server, configuration))
 875    }
 876
 877    async fn handle_get_context_server_command(
 878        this: Entity<Self>,
 879        envelope: TypedEnvelope<proto::GetContextServerCommand>,
 880        mut cx: AsyncApp,
 881    ) -> Result<proto::ContextServerCommand> {
 882        let server_id = ContextServerId(envelope.payload.server_id.into());
 883
 884        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
 885            let ContextServerStoreState::Local {
 886                is_headless: true, ..
 887            } = &this.state
 888            else {
 889                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
 890            };
 891
 892            let settings = this
 893                .context_server_settings
 894                .get(&server_id.0)
 895                .cloned()
 896                .or_else(|| {
 897                    this.registry
 898                        .read(inner_cx)
 899                        .context_server_descriptor(&server_id.0)
 900                        .map(|_| ContextServerSettings::default_extension())
 901                })
 902                .with_context(|| format!("context server `{}` not found", server_id))?;
 903
 904            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
 905        })?;
 906
 907        let configuration = ContextServerConfiguration::from_settings(
 908            settings,
 909            server_id.clone(),
 910            registry,
 911            worktree_store,
 912            &cx,
 913        )
 914        .await
 915        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
 916
 917        let command = configuration
 918            .command()
 919            .context("context server has no command (HTTP servers don't need RPC)")?;
 920
 921        Ok(proto::ContextServerCommand {
 922            path: command.path.display().to_string(),
 923            args: command.args.clone(),
 924            env: command
 925                .env
 926                .clone()
 927                .map(|env| env.into_iter().collect())
 928                .unwrap_or_default(),
 929        })
 930    }
 931
 932    fn resolve_project_settings<'a>(
 933        worktree_store: &'a Entity<WorktreeStore>,
 934        cx: &'a App,
 935    ) -> &'a ProjectSettings {
 936        let location = worktree_store
 937            .read(cx)
 938            .visible_worktrees(cx)
 939            .next()
 940            .map(|worktree| settings::SettingsLocation {
 941                worktree_id: worktree.read(cx).id(),
 942                path: RelPath::empty(),
 943            });
 944        ProjectSettings::get(location, cx)
 945    }
 946
 947    fn create_oauth_token_provider(
 948        id: &ContextServerId,
 949        server_url: &url::Url,
 950        session: OAuthSession,
 951        http_client: Arc<dyn HttpClient>,
 952        credentials_provider: Arc<dyn CredentialsProvider>,
 953        cx: &mut AsyncApp,
 954    ) -> Arc<dyn oauth::OAuthTokenProvider> {
 955        let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
 956        let id = id.clone();
 957        let server_url = server_url.clone();
 958
 959        cx.spawn(async move |cx| {
 960            while let Some(refreshed_session) = token_refresh_rx.next().await {
 961                if let Err(err) =
 962                    Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
 963                        .await
 964                {
 965                    log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
 966                }
 967            }
 968            log::debug!("{} OAuth session persistence task ended", id);
 969        })
 970        .detach();
 971
 972        Arc::new(McpOAuthTokenProvider::new(
 973            session,
 974            http_client,
 975            Some(token_refresh_tx),
 976        ))
 977    }
 978
 979    /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
 980    ///
 981    /// This starts a loopback HTTP callback server on an ephemeral port, builds
 982    /// the authorization URL, opens the user's browser, waits for the callback,
 983    /// exchanges the code for tokens, persists them in the keychain, and restarts
 984    /// the server with the new token provider.
 985    pub fn authenticate_server(
 986        &mut self,
 987        id: &ContextServerId,
 988        cx: &mut Context<Self>,
 989    ) -> Result<()> {
 990        let state = self.servers.get(id).context("Context server not found")?;
 991
 992        let (discovery, server, configuration) = match state {
 993            ContextServerState::AuthRequired {
 994                discovery,
 995                server,
 996                configuration,
 997            } => (discovery.clone(), server.clone(), configuration.clone()),
 998            _ => anyhow::bail!("Server is not in AuthRequired state"),
 999        };
1000
1001        let id = id.clone();
1002
1003        let task = cx.spawn({
1004            let id = id.clone();
1005            let server = server.clone();
1006            let configuration = configuration.clone();
1007            async move |this, cx| {
1008                let result = Self::run_oauth_flow(
1009                    this.clone(),
1010                    id.clone(),
1011                    discovery.clone(),
1012                    configuration.clone(),
1013                    cx,
1014                )
1015                .await;
1016
1017                if let Err(err) = &result {
1018                    log::error!("{} OAuth authentication failed: {:?}", id, err);
1019                    // Transition back to AuthRequired so the user can retry
1020                    // rather than landing in a terminal Error state.
1021                    this.update(cx, |this, cx| {
1022                        this.update_server_state(
1023                            id.clone(),
1024                            ContextServerState::AuthRequired {
1025                                server,
1026                                configuration,
1027                                discovery,
1028                            },
1029                            cx,
1030                        )
1031                    })
1032                    .log_err();
1033                }
1034            }
1035        });
1036
1037        self.update_server_state(
1038            id,
1039            ContextServerState::Authenticating {
1040                server,
1041                configuration,
1042                _task: task,
1043            },
1044            cx,
1045        );
1046
1047        Ok(())
1048    }
1049
1050    async fn run_oauth_flow(
1051        this: WeakEntity<Self>,
1052        id: ContextServerId,
1053        discovery: Arc<OAuthDiscovery>,
1054        configuration: Arc<ContextServerConfiguration>,
1055        cx: &mut AsyncApp,
1056    ) -> Result<()> {
1057        let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1058        let pkce = oauth::generate_pkce_challenge();
1059
1060        let mut state_bytes = [0u8; 32];
1061        rand::rng().fill(&mut state_bytes);
1062        let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1063
1064        // Start a loopback HTTP server on an ephemeral port. The redirect URI
1065        // includes this port so the browser sends the callback directly to our
1066        // process.
1067        let (redirect_uri, callback_rx) = oauth::start_callback_server()
1068            .await
1069            .context("Failed to start OAuth callback server")?;
1070
1071        let http_client = cx.update(|cx| cx.http_client());
1072        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1073        let server_url = match configuration.as_ref() {
1074            ContextServerConfiguration::Http { url, .. } => url.clone(),
1075            _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1076        };
1077
1078        let client_registration =
1079            oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1080                .await
1081                .context("Failed to resolve OAuth client registration")?;
1082
1083        let auth_url = oauth::build_authorization_url(
1084            &discovery.auth_server_metadata,
1085            &client_registration.client_id,
1086            &redirect_uri,
1087            &discovery.scopes,
1088            &resource,
1089            &pkce,
1090            &state_param,
1091        );
1092
1093        cx.update(|cx| cx.open_url(auth_url.as_str()));
1094
1095        let callback = callback_rx
1096            .await
1097            .map_err(|_| {
1098                anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1099            })?
1100            .context("OAuth callback server received an invalid request")?;
1101
1102        if callback.state != state_param {
1103            anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1104        }
1105
1106        let tokens = oauth::exchange_code(
1107            &http_client,
1108            &discovery.auth_server_metadata,
1109            &callback.code,
1110            &client_registration.client_id,
1111            &redirect_uri,
1112            &pkce.verifier,
1113            &resource,
1114        )
1115        .await
1116        .context("Failed to exchange authorization code for tokens")?;
1117
1118        let session = OAuthSession {
1119            token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1120            resource: discovery.resource_metadata.resource.clone(),
1121            client_registration,
1122            tokens,
1123        };
1124
1125        Self::store_session(&credentials_provider, &server_url, &session, cx)
1126            .await
1127            .context("Failed to persist OAuth session in keychain")?;
1128
1129        let token_provider = Self::create_oauth_token_provider(
1130            &id,
1131            &server_url,
1132            session,
1133            http_client.clone(),
1134            credentials_provider,
1135            cx,
1136        );
1137
1138        let new_server = this.update(cx, |this, cx| {
1139            let global_timeout =
1140                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1141
1142            match configuration.as_ref() {
1143                ContextServerConfiguration::Http {
1144                    url,
1145                    headers,
1146                    timeout,
1147                } => {
1148                    let transport = HttpTransport::new_with_token_provider(
1149                        http_client.clone(),
1150                        url.to_string(),
1151                        headers.clone(),
1152                        cx.background_executor().clone(),
1153                        Some(token_provider.clone()),
1154                    );
1155                    Ok(Arc::new(ContextServer::new_with_timeout(
1156                        id.clone(),
1157                        Arc::new(transport),
1158                        Some(Duration::from_secs(
1159                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1160                        )),
1161                    )))
1162                }
1163                _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1164            }
1165        })??;
1166
1167        this.update(cx, |this, cx| {
1168            this.run_server(new_server, configuration, cx);
1169        })?;
1170
1171        Ok(())
1172    }
1173
1174    /// Store the full OAuth session in the system keychain, keyed by the
1175    /// server's canonical URI.
1176    async fn store_session(
1177        credentials_provider: &Arc<dyn CredentialsProvider>,
1178        server_url: &url::Url,
1179        session: &OAuthSession,
1180        cx: &AsyncApp,
1181    ) -> Result<()> {
1182        let key = Self::keychain_key(server_url);
1183        let json = serde_json::to_string(session)?;
1184        credentials_provider
1185            .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1186            .await
1187    }
1188
1189    /// Load the full OAuth session from the system keychain for the given
1190    /// server URL.
1191    async fn load_session(
1192        credentials_provider: &Arc<dyn CredentialsProvider>,
1193        server_url: &url::Url,
1194        cx: &AsyncApp,
1195    ) -> Result<Option<OAuthSession>> {
1196        let key = Self::keychain_key(server_url);
1197        match credentials_provider.read_credentials(&key, cx).await? {
1198            Some((_username, password_bytes)) => {
1199                let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1200                Ok(Some(session))
1201            }
1202            None => Ok(None),
1203        }
1204    }
1205
1206    /// Clear the stored OAuth session from the system keychain.
1207    async fn clear_session(
1208        credentials_provider: &Arc<dyn CredentialsProvider>,
1209        server_url: &url::Url,
1210        cx: &AsyncApp,
1211    ) -> Result<()> {
1212        let key = Self::keychain_key(server_url);
1213        credentials_provider.delete_credentials(&key, cx).await
1214    }
1215
1216    fn keychain_key(server_url: &url::Url) -> String {
1217        format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1218    }
1219
1220    /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1221    /// session from the keychain and stop the server.
1222    pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1223        let state = self.servers.get(id).context("Context server not found")?;
1224        let configuration = state.configuration();
1225
1226        let server_url = match configuration.as_ref() {
1227            ContextServerConfiguration::Http { url, .. } => url.clone(),
1228            _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1229        };
1230
1231        let id = id.clone();
1232        self.stop_server(&id, cx)?;
1233
1234        cx.spawn(async move |this, cx| {
1235            let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1236            if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1237                log::error!("{} failed to clear OAuth session: {}", id, err);
1238            }
1239            // Trigger server recreation so the next start uses a fresh
1240            // transport without the old (now-invalidated) token provider.
1241            this.update(cx, |this, cx| {
1242                this.available_context_servers_changed(cx);
1243            })
1244            .log_err();
1245        })
1246        .detach();
1247
1248        Ok(())
1249    }
1250
1251    fn update_server_state(
1252        &mut self,
1253        id: ContextServerId,
1254        state: ContextServerState,
1255        cx: &mut Context<Self>,
1256    ) {
1257        let status = ContextServerStatus::from_state(&state);
1258        self.servers.insert(id.clone(), state);
1259        cx.emit(ServerStatusChangedEvent {
1260            server_id: id,
1261            status,
1262        });
1263    }
1264
1265    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1266        if self.update_servers_task.is_some() {
1267            self.needs_server_update = true;
1268        } else {
1269            self.needs_server_update = false;
1270            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1271                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1272                    log::error!("Error maintaining context servers: {}", err);
1273                }
1274
1275                this.update(cx, |this, cx| {
1276                    this.populate_server_ids(cx);
1277                    cx.notify();
1278                    this.update_servers_task.take();
1279                    if this.needs_server_update {
1280                        this.available_context_servers_changed(cx);
1281                    }
1282                })?;
1283
1284                Ok(())
1285            }));
1286        }
1287    }
1288
    /// Reconciles the store's servers with the desired set derived from the
    /// user's settings and the extension registry.
    ///
    /// - If `DisableAiSettings::disable_ai` is set, every known server is
    ///   stopped (errors ignored) and nothing is started.
    /// - Otherwise, registry descriptors without explicit settings are added
    ///   with `ContextServerSettings::default_extension()`, and all entries are
    ///   partitioned by `settings.enabled()`.
    /// - Enabled servers whose configuration can be built are (re)started when
    ///   they are new, their configuration differs from the current one, or
    ///   they are in the `Stopped` state. Existing instances are stopped first.
    /// - Known servers without a buildable enabled configuration are stopped if
    ///   they are listed as disabled, and removed from the store otherwise.
    ///
    /// When creating a server fails, a `ServerStatusChangedEvent` with
    /// `ContextServerStatus::Error` is emitted; this function does not update
    /// `servers` for that id.
    ///
    /// # Errors
    /// Returns an error if the store entity has been dropped, or if
    /// `stop_server`/`remove_server` fails. In the latter case the remaining
    /// stops/removals and all starts of this pass are skipped.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        // Don't start context servers if AI is disabled
        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
        if ai_disabled {
            // Stop all running servers when AI is disabled
            this.update(cx, |this, cx| {
                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
                for id in server_ids {
                    let _ = this.stop_server(&id, cx);
                }
            })?;
            return Ok(());
        }

        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
            (
                this.context_server_settings.clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Registry-provided servers without explicit user settings get the
        // default extension settings; explicit settings are left untouched.
        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::default_extension());
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Build all enabled configurations concurrently. Servers whose
        // configuration cannot be built are dropped here, so if they already
        // exist in the store they are treated as "not configured" below.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(move |config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        this.update(cx, |this, _cx| {
            for server_id in this.servers.keys() {
                // Servers without a buildable enabled configuration: stop the
                // ones explicitly disabled in settings, and remove the rest from
                // the store (e.g. the user deleted the server from settings).
                if !configured_servers.contains_key(server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            // (Re)start servers that are new, whose configuration changed, or
            // that are currently Stopped; existing instances are stopped first.
            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    servers_to_start.push((id.clone(), config));
                    if this.servers.contains_key(&id) {
                        servers_to_stop.insert(id);
                    }
                }
            }

            anyhow::Ok(())
        })??;

        // The first failing stop/remove aborts the rest of this pass.
        this.update(cx, |this, inner_cx| {
            for id in servers_to_stop {
                this.stop_server(&id, inner_cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, inner_cx)?;
            }
            anyhow::Ok(())
        })??;

        // Servers are created sequentially; a creation failure is reported via
        // an Error status event and does not stop the remaining starts.
        for (id, config) in servers_to_start {
            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
                Ok((server, config)) => {
                    this.update(cx, |this, cx| {
                        this.run_server(server, config, cx);
                    })?;
                }
                Err(err) => {
                    log::error!("{id} context server failed to create: {err:#}");
                    this.update(cx, |_this, cx| {
                        cx.emit(ServerStatusChangedEvent {
                            server_id: id,
                            status: ContextServerStatus::Error(err.to_string().into()),
                        });
                        cx.notify();
                    })?;
                }
            }
        }

        Ok(())
    }
1403}
1404
/// Determines the state a server should move to after a start attempt fails.
///
/// - A 401 (`TransportError::AuthRequired`) while a static Authorization header
///   is configured yields `Error`, pointing the user at that header.
/// - Any failure on a non-HTTP server, or on an HTTP server with a static auth
///   header, yields `Error` carrying `err`'s message.
/// - For other HTTP servers, a non-401 failure yields `Error` unless a cached
///   OAuth session was loaded from the keychain. In that case the session is
///   cleared and OAuth discovery is attempted.
/// - A 401, or the cached-session case above, runs OAuth discovery. Success
///   yields `AuthRequired` (so the UI can offer authentication); failure yields
///   `Error` describing the discovery failure.
async fn resolve_start_failure(
    id: &ContextServerId,
    err: anyhow::Error,
    server: Arc<ContextServer>,
    configuration: Arc<ContextServerConfiguration>,
    cx: &AsyncApp,
) -> ContextServerState {
    // `Some(..)` only when the transport reported a 401 / auth-required error.
    let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
        TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
    });

    if www_authenticate.is_some() && configuration.has_static_auth_header() {
        log::warn!("{id} received 401 with a static Authorization header configured");
        return ContextServerState::Error {
            configuration,
            server,
            error: "Server returned 401 Unauthorized. Check your configured Authorization header."
                .into(),
        };
    }

    // OAuth only applies to HTTP servers without a static auth header. The
    // static-header + 401 case already returned above, so a 401 reaching the
    // fallback arm comes from a non-HTTP configuration.
    let server_url = match configuration.as_ref() {
        ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
            url.clone()
        }
        _ => {
            if www_authenticate.is_some() {
                log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
            } else {
                log::error!("{id} context server failed to start: {err}");
            }
            return ContextServerState::Error {
                configuration,
                server,
                error: err.to_string().into(),
            };
        }
    };

    // When the error is NOT a 401 but a cached OAuth session exists in the
    // keychain, the session is assumed to be stale/expired and to have caused
    // the failure (e.g. a timeout because the server silently rejected the
    // token). The session is cleared, and this function then proceeds
    // directly to OAuth discovery below (with an empty WWW-Authenticate
    // challenge) instead of waiting for another start attempt.
    // NOTE(review): any non-401 failure (e.g. the server being temporarily
    // unreachable) clears the cached session. If discovery then fails, the
    // reported error is the discovery error, not `err`. Confirm both are intended.
    if www_authenticate.is_none() {
        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
        match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
            Ok(Some(_)) => {
                log::info!("{id} start failed with a cached OAuth session present; clearing it");
                ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
                    .await
                    .log_err();
            }
            _ => {
                log::error!("{id} context server failed to start: {err}");
                return ContextServerState::Error {
                    configuration,
                    server,
                    error: err.to_string().into(),
                };
            }
        }
    }

    // Without a real challenge (cached-session path), discover with an empty one.
    let default_www_authenticate = oauth::WwwAuthenticate {
        resource_metadata: None,
        scope: None,
        error: None,
        error_description: None,
    };
    let www_authenticate = www_authenticate
        .as_ref()
        .unwrap_or(&default_www_authenticate);
    let http_client = cx.update(|cx| cx.http_client());

    match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
        Ok(discovery) => {
            log::info!(
                "{id} requires OAuth authorization (auth server: {})",
                discovery.auth_server_metadata.issuer,
            );
            ContextServerState::AuthRequired {
                server,
                configuration,
                discovery: Arc::new(discovery),
            }
        }
        Err(discovery_err) => {
            log::error!("{id} OAuth discovery failed: {discovery_err}");
            ContextServerState::Error {
                configuration,
                server,
                error: format!("OAuth discovery failed: {discovery_err}").into(),
            }
        }
    }
}