context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::path::Path;
   5use std::sync::Arc;
   6use std::time::Duration;
   7
   8use anyhow::{Context as _, Result};
   9use collections::{HashMap, HashSet};
  10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
  11use context_server::transport::{HttpTransport, TransportError};
  12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  13use credentials_provider::CredentialsProvider;
  14use futures::future::Either;
  15use futures::{FutureExt as _, StreamExt as _, future::join_all};
  16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  17use http_client::HttpClient;
  18use itertools::Itertools;
  19use rand::Rng as _;
  20use registry::ContextServerDescriptorRegistry;
  21use remote::RemoteClient;
  22use rpc::{AnyProtoClient, TypedEnvelope, proto};
  23use settings::{Settings as _, SettingsStore};
  24use util::{ResultExt as _, rel_path::RelPath};
  25
  26use crate::{
  27    DisableAiSettings, Project,
  28    project_settings::{ContextServerSettings, ProjectSettings},
  29    worktree_store::WorktreeStore,
  30};
  31
  32/// Maximum timeout for context server requests
  33/// Prevents extremely large timeout values from tying up resources indefinitely.
  34const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
  35
  36pub fn init(cx: &mut App) {
  37    extension::init(cx);
  38}
  39
// Declares the actions in the `context_server` namespace. The `///` text on
// each action is part of the macro input and is left unchanged.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  47
/// Public view of a context server's lifecycle with internal handles stripped.
/// It is derived from the store's internal state by
/// `ContextServerStatus::from_state` and carried by [`ServerStatusChangedEvent`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ContextServerStatus {
    /// The server is being started.
    Starting,
    /// The server started successfully.
    Running,
    /// The server was stopped, or removed from the store.
    Stopped,
    /// The server is in an error state; carries the error message.
    Error(Arc<str>),
    /// The server returned 401 and OAuth authorization is needed. The UI
    /// should show an "Authenticate" button.
    AuthRequired,
    /// The OAuth browser flow is in progress — the user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating,
}
  61
  62impl ContextServerStatus {
  63    fn from_state(state: &ContextServerState) -> Self {
  64        match state {
  65            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  66            ContextServerState::Running { .. } => ContextServerStatus::Running,
  67            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  68            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  69            ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
  70            ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
  71        }
  72    }
  73}
  74
/// Internal lifecycle state of a tracked context server. Every variant keeps
/// the server handle and the configuration it was, or will be, started with.
enum ContextServerState {
    /// `ContextServer::start` is in progress. `_task` awaits it and then
    /// records the resulting state; replacing or removing this state drops the
    /// task.
    Starting {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        _task: Task<()>,
    },
    /// The server started successfully.
    Running {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server was stopped through `ContextServerStore::stop_server`.
    Stopped {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
    },
    /// The server is in an error state described by `error`.
    /// NOTE(review): not constructed in the visible code; presumably produced
    /// by `resolve_start_failure` when starting fails — confirm.
    Error {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        error: Arc<str>,
    },
    /// The server requires OAuth authorization before it can be used. The
    /// `OAuthDiscovery` holds everything needed to start the browser flow.
    AuthRequired {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        discovery: Arc<OAuthDiscovery>,
    },
    /// The OAuth browser flow is in progress. The user has been redirected
    /// to the authorization server and we're waiting for the callback.
    Authenticating {
        server: Arc<ContextServer>,
        configuration: Arc<ContextServerConfiguration>,
        /// NOTE(review): presumably the task driving the OAuth flow; it is
        /// dropped together with this state — confirm.
        _task: Task<()>,
    },
}
 109
 110impl ContextServerState {
 111    pub fn server(&self) -> Arc<ContextServer> {
 112        match self {
 113            ContextServerState::Starting { server, .. }
 114            | ContextServerState::Running { server, .. }
 115            | ContextServerState::Stopped { server, .. }
 116            | ContextServerState::Error { server, .. }
 117            | ContextServerState::AuthRequired { server, .. }
 118            | ContextServerState::Authenticating { server, .. } => server.clone(),
 119        }
 120    }
 121
 122    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
 123        match self {
 124            ContextServerState::Starting { configuration, .. }
 125            | ContextServerState::Running { configuration, .. }
 126            | ContextServerState::Stopped { configuration, .. }
 127            | ContextServerState::Error { configuration, .. }
 128            | ContextServerState::AuthRequired { configuration, .. }
 129            | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
 130        }
 131    }
 132}
 133
/// A resolved, ready-to-launch description of a context server, produced by
/// [`ContextServerConfiguration::from_settings`].
#[derive(Debug, PartialEq, Eq)]
pub enum ContextServerConfiguration {
    /// A stdio server launched with the command given in the settings.
    Custom {
        command: ContextServerCommand,
        /// When set and the project is remote, the launch command is requested
        /// from the remote host instead of being run as-is.
        remote: bool,
    },
    /// A stdio server whose command was provided by an extension's descriptor.
    Extension {
        command: ContextServerCommand,
        /// The `settings` value of this server's settings entry.
        settings: serde_json::Value,
        /// Same meaning as `Custom::remote`.
        remote: bool,
    },
    /// A server reached over HTTP.
    Http {
        url: url::Url,
        /// Extra request headers. An `Authorization` entry skips loading a
        /// cached OAuth session (see `has_static_auth_header`).
        headers: HashMap<String, String>,
        /// Request timeout in seconds. When unset, the global
        /// `context_server_timeout` setting is used; either value is capped at
        /// `MAX_TIMEOUT_SECS`.
        timeout: Option<u64>,
    },
}
 151
 152impl ContextServerConfiguration {
 153    pub fn command(&self) -> Option<&ContextServerCommand> {
 154        match self {
 155            ContextServerConfiguration::Custom { command, .. } => Some(command),
 156            ContextServerConfiguration::Extension { command, .. } => Some(command),
 157            ContextServerConfiguration::Http { .. } => None,
 158        }
 159    }
 160
 161    pub fn has_static_auth_header(&self) -> bool {
 162        match self {
 163            ContextServerConfiguration::Http { headers, .. } => headers
 164                .keys()
 165                .any(|k| k.eq_ignore_ascii_case("authorization")),
 166            _ => false,
 167        }
 168    }
 169
 170    pub fn remote(&self) -> bool {
 171        match self {
 172            ContextServerConfiguration::Custom { remote, .. } => *remote,
 173            ContextServerConfiguration::Extension { remote, .. } => *remote,
 174            ContextServerConfiguration::Http { .. } => false,
 175        }
 176    }
 177
 178    pub async fn from_settings(
 179        settings: ContextServerSettings,
 180        id: ContextServerId,
 181        registry: Entity<ContextServerDescriptorRegistry>,
 182        worktree_store: Entity<WorktreeStore>,
 183        cx: &AsyncApp,
 184    ) -> Option<Self> {
 185        const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
 186
 187        match settings {
 188            ContextServerSettings::Stdio {
 189                enabled: _,
 190                command,
 191                remote,
 192            } => Some(ContextServerConfiguration::Custom { command, remote }),
 193            ContextServerSettings::Extension {
 194                enabled: _,
 195                settings,
 196                remote,
 197            } => {
 198                let descriptor =
 199                    cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
 200
 201                let command_future = descriptor.command(worktree_store, cx);
 202                let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
 203
 204                match futures::future::select(command_future, timeout_future).await {
 205                    Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
 206                        command,
 207                        settings,
 208                        remote,
 209                    }),
 210                    Either::Left((Err(e), _)) => {
 211                        log::error!(
 212                            "Failed to create context server configuration from settings: {e:#}"
 213                        );
 214                        None
 215                    }
 216                    Either::Right(_) => {
 217                        log::error!(
 218                            "Timed out resolving command for extension context server {id}"
 219                        );
 220                        None
 221                    }
 222                }
 223            }
 224            ContextServerSettings::Http {
 225                enabled: _,
 226                url,
 227                headers: auth,
 228                timeout,
 229            } => {
 230                let url = url::Url::parse(&url).log_err()?;
 231                Some(ContextServerConfiguration::Http {
 232                    url,
 233                    headers: auth,
 234                    timeout,
 235                })
 236            }
 237        }
 238    }
 239}
 240
/// Builds a [`ContextServer`] from an id and its configuration. When set on the
/// store, `create_context_server` uses it instead of constructing the default
/// stdio/HTTP server.
pub type ContextServerFactory =
    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 243
/// Whether this store manages servers for a local project or a remote one.
enum ContextServerStoreState {
    /// Store created by `ContextServerStore::local` (or the test constructors).
    Local {
        /// Set by `ContextServerStore::shared` to the project id and the client
        /// the project was shared with.
        downstream_client: Option<(u64, AnyProtoClient)>,
        /// The `headless` flag passed to `ContextServerStore::local`.
        is_headless: bool,
    },
    /// Store created by `ContextServerStore::remote`. For servers configured as
    /// `remote`, launch commands are requested from the host through
    /// `upstream_client`.
    Remote {
        project_id: u64,
        upstream_client: Entity<RemoteClient>,
    },
}
 254
/// Tracks a project's context servers: their settings, their current lifecycle
/// state, and what is needed to (re)start them. Emits
/// [`ServerStatusChangedEvent`]s.
pub struct ContextServerStore {
    /// Local or remote mode (see [`ContextServerStoreState`]).
    state: ContextServerStoreState,
    /// Context-server entries from the project settings, keyed by server id.
    /// Refreshed from the project settings on every `SettingsStore` change.
    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
    /// Current state of every server the store is tracking.
    servers: HashMap<ContextServerId, ContextServerState>,
    /// Sorted, de-duplicated id cache computed by `populate_server_ids`.
    server_ids: Vec<ContextServerId>,
    worktree_store: Entity<WorktreeStore>,
    /// Weak handle to the owning project; used to pick the root directory
    /// handed to spawned servers.
    project: Option<WeakEntity<Project>>,
    /// Descriptor registry; `Extension` settings resolve their command
    /// through it.
    registry: Entity<ContextServerDescriptorRegistry>,
    /// NOTE(review): not used in the visible part of this file; presumably the
    /// in-flight task that reconciles servers with settings — confirm.
    update_servers_task: Option<Task<Result<()>>>,
    /// Optional override used by `create_context_server` to build servers.
    /// Supplied by the test-support constructors.
    context_server_factory: Option<ContextServerFactory>,
    /// NOTE(review): not read in the visible part of this file; presumably
    /// marks that another server update is pending — confirm.
    needs_server_update: bool,
    /// Last observed value of `DisableAiSettings::disable_ai`.
    ai_disabled: bool,
    /// Settings and registry observers, held only to keep them alive.
    _subscriptions: Vec<Subscription>,
}
 269
/// Emitted by [`ContextServerStore`] when a server's status changes.
pub struct ServerStatusChangedEvent {
    /// The server whose status changed.
    pub server_id: ContextServerId,
    /// The server's new status.
    pub status: ContextServerStatus,
}

// Lets `ContextServerStore` entities emit `ServerStatusChangedEvent`s.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
 276
 277impl ContextServerStore {
 278    pub fn local(
 279        worktree_store: Entity<WorktreeStore>,
 280        weak_project: Option<WeakEntity<Project>>,
 281        headless: bool,
 282        cx: &mut Context<Self>,
 283    ) -> Self {
 284        Self::new_internal(
 285            !headless,
 286            None,
 287            ContextServerDescriptorRegistry::default_global(cx),
 288            worktree_store,
 289            weak_project,
 290            ContextServerStoreState::Local {
 291                downstream_client: None,
 292                is_headless: headless,
 293            },
 294            cx,
 295        )
 296    }
 297
 298    pub fn remote(
 299        project_id: u64,
 300        upstream_client: Entity<RemoteClient>,
 301        worktree_store: Entity<WorktreeStore>,
 302        weak_project: Option<WeakEntity<Project>>,
 303        cx: &mut Context<Self>,
 304    ) -> Self {
 305        Self::new_internal(
 306            true,
 307            None,
 308            ContextServerDescriptorRegistry::default_global(cx),
 309            worktree_store,
 310            weak_project,
 311            ContextServerStoreState::Remote {
 312                project_id,
 313                upstream_client,
 314            },
 315            cx,
 316        )
 317    }
 318
 319    pub fn init_headless(session: &AnyProtoClient) {
 320        session.add_entity_request_handler(Self::handle_get_context_server_command);
 321    }
 322
 323    pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
 324        if let ContextServerStoreState::Local {
 325            downstream_client, ..
 326        } = &mut self.state
 327        {
 328            *downstream_client = Some((project_id, client));
 329        }
 330    }
 331
 332    pub fn is_remote_project(&self) -> bool {
 333        matches!(self.state, ContextServerStoreState::Remote { .. })
 334    }
 335
 336    /// Returns all configured context server ids, excluding the ones that are disabled
 337    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 338        self.context_server_settings
 339            .iter()
 340            .filter(|(_, settings)| settings.enabled())
 341            .map(|(id, _)| ContextServerId(id.clone()))
 342            .collect()
 343    }
 344
 345    #[cfg(feature = "test-support")]
 346    pub fn test(
 347        registry: Entity<ContextServerDescriptorRegistry>,
 348        worktree_store: Entity<WorktreeStore>,
 349        weak_project: Option<WeakEntity<Project>>,
 350        cx: &mut Context<Self>,
 351    ) -> Self {
 352        Self::new_internal(
 353            false,
 354            None,
 355            registry,
 356            worktree_store,
 357            weak_project,
 358            ContextServerStoreState::Local {
 359                downstream_client: None,
 360                is_headless: false,
 361            },
 362            cx,
 363        )
 364    }
 365
 366    #[cfg(feature = "test-support")]
 367    pub fn test_maintain_server_loop(
 368        context_server_factory: Option<ContextServerFactory>,
 369        registry: Entity<ContextServerDescriptorRegistry>,
 370        worktree_store: Entity<WorktreeStore>,
 371        weak_project: Option<WeakEntity<Project>>,
 372        cx: &mut Context<Self>,
 373    ) -> Self {
 374        Self::new_internal(
 375            true,
 376            context_server_factory,
 377            registry,
 378            worktree_store,
 379            weak_project,
 380            ContextServerStoreState::Local {
 381                downstream_client: None,
 382                is_headless: false,
 383            },
 384            cx,
 385        )
 386    }
 387
 388    #[cfg(feature = "test-support")]
 389    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
 390        self.context_server_factory = Some(factory);
 391    }
 392
 393    #[cfg(feature = "test-support")]
 394    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
 395        &self.registry
 396    }
 397
 398    #[cfg(feature = "test-support")]
 399    pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 400        let configuration = Arc::new(ContextServerConfiguration::Custom {
 401            command: ContextServerCommand {
 402                path: "test".into(),
 403                args: vec![],
 404                env: None,
 405                timeout: None,
 406            },
 407            remote: false,
 408        });
 409        self.run_server(server, configuration, cx);
 410    }
 411
 412    fn new_internal(
 413        maintain_server_loop: bool,
 414        context_server_factory: Option<ContextServerFactory>,
 415        registry: Entity<ContextServerDescriptorRegistry>,
 416        worktree_store: Entity<WorktreeStore>,
 417        weak_project: Option<WeakEntity<Project>>,
 418        state: ContextServerStoreState,
 419        cx: &mut Context<Self>,
 420    ) -> Self {
 421        let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 422            let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 423            let ai_was_disabled = this.ai_disabled;
 424            this.ai_disabled = ai_disabled;
 425
 426            let settings =
 427                &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
 428            let settings_changed = &this.context_server_settings != settings;
 429
 430            if settings_changed {
 431                this.context_server_settings = settings.clone();
 432            }
 433
 434            // When AI is disabled, stop all running servers
 435            if ai_disabled {
 436                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
 437                for id in server_ids {
 438                    this.stop_server(&id, cx).log_err();
 439                }
 440                return;
 441            }
 442
 443            // Trigger updates if AI was re-enabled or settings changed
 444            if maintain_server_loop && (ai_was_disabled || settings_changed) {
 445                this.available_context_servers_changed(cx);
 446            }
 447        })];
 448
 449        if maintain_server_loop {
 450            subscriptions.push(cx.observe(&registry, |this, _registry, cx| {
 451                if !DisableAiSettings::get_global(cx).disable_ai {
 452                    this.available_context_servers_changed(cx);
 453                }
 454            }));
 455        }
 456
 457        let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 458        let mut this = Self {
 459            state,
 460            _subscriptions: subscriptions,
 461            context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
 462                .context_servers
 463                .clone(),
 464            worktree_store,
 465            project: weak_project,
 466            registry,
 467            needs_server_update: false,
 468            ai_disabled,
 469            servers: HashMap::default(),
 470            server_ids: Default::default(),
 471            update_servers_task: None,
 472            context_server_factory,
 473        };
 474        if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
 475            this.available_context_servers_changed(cx);
 476        }
 477        this
 478    }
 479
 480    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 481        self.servers.get(id).map(|state| state.server())
 482    }
 483
 484    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 485        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 486            Some(server.clone())
 487        } else {
 488            None
 489        }
 490    }
 491
 492    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 493        self.servers.get(id).map(ContextServerStatus::from_state)
 494    }
 495
 496    pub fn configuration_for_server(
 497        &self,
 498        id: &ContextServerId,
 499    ) -> Option<Arc<ContextServerConfiguration>> {
 500        self.servers.get(id).map(|state| state.configuration())
 501    }
 502
 503    /// Returns a sorted slice of available unique context server IDs. Within the
 504    /// slice, context servers which have `mcp-server-` as a prefix in their ID will
 505    /// appear after servers that do not have this prefix in their ID.
 506    pub fn server_ids(&self) -> &[ContextServerId] {
 507        self.server_ids.as_slice()
 508    }
 509
 510    fn populate_server_ids(&mut self, cx: &App) {
 511        self.server_ids = self
 512            .servers
 513            .keys()
 514            .cloned()
 515            .chain(
 516                self.registry
 517                    .read(cx)
 518                    .context_server_descriptors()
 519                    .into_iter()
 520                    .map(|(id, _)| ContextServerId(id)),
 521            )
 522            .chain(
 523                self.context_server_settings
 524                    .keys()
 525                    .map(|id| ContextServerId(id.clone())),
 526            )
 527            .unique()
 528            .sorted_unstable_by(
 529                // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
 530                |a, b| {
 531                    const MCP_PREFIX: &str = "mcp-server-";
 532                    match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
 533                        // If one has mcp-server- prefix and other doesn't, non-mcp comes first
 534                        (Some(_), None) => std::cmp::Ordering::Greater,
 535                        (None, Some(_)) => std::cmp::Ordering::Less,
 536                        // If both have same prefix status, sort by appropriate key
 537                        (Some(a), Some(b)) => a.cmp(b),
 538                        (None, None) => a.0.cmp(&b.0),
 539                    }
 540                },
 541            )
 542            .collect();
 543    }
 544
 545    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 546        self.servers
 547            .values()
 548            .filter_map(|state| {
 549                if let ContextServerState::Running { server, .. } = state {
 550                    Some(server.clone())
 551                } else {
 552                    None
 553                }
 554            })
 555            .collect()
 556    }
 557
 558    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 559        cx.spawn(async move |this, cx| {
 560            let this = this.upgrade().context("Context server store dropped")?;
 561            let id = server.id();
 562            let settings = this
 563                .update(cx, |this, _| {
 564                    this.context_server_settings.get(&id.0).cloned()
 565                })
 566                .context("Failed to get context server settings")?;
 567
 568            if !settings.enabled() {
 569                return anyhow::Ok(());
 570            }
 571
 572            let (registry, worktree_store) = this.update(cx, |this, _| {
 573                (this.registry.clone(), this.worktree_store.clone())
 574            });
 575            let configuration = ContextServerConfiguration::from_settings(
 576                settings,
 577                id.clone(),
 578                registry,
 579                worktree_store,
 580                cx,
 581            )
 582            .await
 583            .context("Failed to create context server configuration")?;
 584
 585            this.update(cx, |this, cx| {
 586                this.run_server(server, Arc::new(configuration), cx)
 587            });
 588            Ok(())
 589        })
 590        .detach_and_log_err(cx);
 591    }
 592
 593    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 594        if matches!(
 595            self.servers.get(id),
 596            Some(ContextServerState::Stopped { .. })
 597        ) {
 598            return Ok(());
 599        }
 600
 601        let state = self
 602            .servers
 603            .remove(id)
 604            .context("Context server not found")?;
 605
 606        let server = state.server();
 607        let configuration = state.configuration();
 608        let mut result = Ok(());
 609        if let ContextServerState::Running { server, .. } = &state {
 610            result = server.stop();
 611        }
 612        drop(state);
 613
 614        self.update_server_state(
 615            id.clone(),
 616            ContextServerState::Stopped {
 617                configuration,
 618                server,
 619            },
 620            cx,
 621        );
 622
 623        result
 624    }
 625
 626    fn run_server(
 627        &mut self,
 628        server: Arc<ContextServer>,
 629        configuration: Arc<ContextServerConfiguration>,
 630        cx: &mut Context<Self>,
 631    ) {
 632        let id = server.id();
 633        if matches!(
 634            self.servers.get(&id),
 635            Some(
 636                ContextServerState::Starting { .. }
 637                    | ContextServerState::Running { .. }
 638                    | ContextServerState::Authenticating { .. },
 639            )
 640        ) {
 641            self.stop_server(&id, cx).log_err();
 642        }
 643        let task = cx.spawn({
 644            let id = server.id();
 645            let server = server.clone();
 646            let configuration = configuration.clone();
 647
 648            async move |this, cx| {
 649                let new_state = match server.clone().start(cx).await {
 650                    Ok(_) => {
 651                        debug_assert!(server.client().is_some());
 652                        ContextServerState::Running {
 653                            server,
 654                            configuration,
 655                        }
 656                    }
 657                    Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
 658                };
 659                this.update(cx, |this, cx| {
 660                    this.update_server_state(id.clone(), new_state, cx)
 661                })
 662                .log_err();
 663            }
 664        });
 665
 666        self.update_server_state(
 667            id.clone(),
 668            ContextServerState::Starting {
 669                configuration,
 670                _task: task,
 671                server,
 672            },
 673            cx,
 674        );
 675    }
 676
 677    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 678        let state = self
 679            .servers
 680            .remove(id)
 681            .context("Context server not found")?;
 682
 683        if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
 684            let server_url = url.clone();
 685            let id = id.clone();
 686            cx.spawn(async move |_this, cx| {
 687                let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
 688                if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
 689                {
 690                    log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
 691                }
 692            })
 693            .detach();
 694        }
 695
 696        drop(state);
 697        cx.emit(ServerStatusChangedEvent {
 698            server_id: id.clone(),
 699            status: ContextServerStatus::Stopped,
 700        });
 701        Ok(())
 702    }
 703
 704    pub async fn create_context_server(
 705        this: WeakEntity<Self>,
 706        id: ContextServerId,
 707        configuration: Arc<ContextServerConfiguration>,
 708        cx: &mut AsyncApp,
 709    ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
 710        let remote = configuration.remote();
 711        let needs_remote_command = match configuration.as_ref() {
 712            ContextServerConfiguration::Custom { .. }
 713            | ContextServerConfiguration::Extension { .. } => remote,
 714            ContextServerConfiguration::Http { .. } => false,
 715        };
 716
 717        let (remote_state, is_remote_project) = this.update(cx, |this, _| {
 718            let remote_state = match &this.state {
 719                ContextServerStoreState::Remote {
 720                    project_id,
 721                    upstream_client,
 722                } if needs_remote_command => Some((*project_id, upstream_client.clone())),
 723                _ => None,
 724            };
 725            (remote_state, this.is_remote_project())
 726        })?;
 727
 728        let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
 729            this.project
 730                .as_ref()
 731                .and_then(|project| {
 732                    project
 733                        .read_with(cx, |project, cx| project.active_project_directory(cx))
 734                        .ok()
 735                        .flatten()
 736                })
 737                .or_else(|| {
 738                    this.worktree_store.read_with(cx, |store, cx| {
 739                        store.visible_worktrees(cx).fold(None, |acc, item| {
 740                            if acc.is_none() {
 741                                item.read(cx).root_dir()
 742                            } else {
 743                                acc
 744                            }
 745                        })
 746                    })
 747                })
 748        })?;
 749
 750        let configuration = if let Some((project_id, upstream_client)) = remote_state {
 751            let root_dir = root_path.as_ref().map(|p| p.display().to_string());
 752
 753            let response = upstream_client
 754                .update(cx, |client, _| {
 755                    client
 756                        .proto_client()
 757                        .request(proto::GetContextServerCommand {
 758                            project_id,
 759                            server_id: id.0.to_string(),
 760                            root_dir: root_dir.clone(),
 761                        })
 762                })
 763                .await?;
 764
 765            let remote_command = upstream_client.update(cx, |client, _| {
 766                client.build_command(
 767                    Some(response.path),
 768                    &response.args,
 769                    &response.env.into_iter().collect(),
 770                    root_dir,
 771                    None,
 772                )
 773            })?;
 774
 775            let command = ContextServerCommand {
 776                path: remote_command.program.into(),
 777                args: remote_command.args,
 778                env: Some(remote_command.env.into_iter().collect()),
 779                timeout: None,
 780            };
 781
 782            Arc::new(ContextServerConfiguration::Custom { command, remote })
 783        } else {
 784            configuration
 785        };
 786
 787        if let Some(server) = this.update(cx, |this, _| {
 788            this.context_server_factory
 789                .as_ref()
 790                .map(|factory| factory(id.clone(), configuration.clone()))
 791        })? {
 792            return Ok((server, configuration));
 793        }
 794
 795        let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
 796            if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
 797                if configuration.has_static_auth_header() {
 798                    None
 799                } else {
 800                    let credentials_provider =
 801                        cx.update(|cx| <dyn CredentialsProvider>::global(cx));
 802                    let http_client = cx.update(|cx| cx.http_client());
 803
 804                    match Self::load_session(&credentials_provider, url, &cx).await {
 805                        Ok(Some(session)) => {
 806                            log::info!("{} loaded cached OAuth session from keychain", id);
 807                            Some(Self::create_oauth_token_provider(
 808                                &id,
 809                                url,
 810                                session,
 811                                http_client,
 812                                credentials_provider,
 813                                cx,
 814                            ))
 815                        }
 816                        Ok(None) => None,
 817                        Err(err) => {
 818                            log::warn!("{} failed to load cached OAuth session: {}", id, err);
 819                            None
 820                        }
 821                    }
 822                }
 823            } else {
 824                None
 825            };
 826
 827        let server: Arc<ContextServer> = this.update(cx, |this, cx| {
 828            let global_timeout =
 829                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
 830
 831            match configuration.as_ref() {
 832                ContextServerConfiguration::Http {
 833                    url,
 834                    headers,
 835                    timeout,
 836                } => {
 837                    let transport = HttpTransport::new_with_token_provider(
 838                        cx.http_client(),
 839                        url.to_string(),
 840                        headers.clone(),
 841                        cx.background_executor().clone(),
 842                        cached_token_provider.clone(),
 843                    );
 844                    anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
 845                        id,
 846                        Arc::new(transport),
 847                        Some(Duration::from_secs(
 848                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
 849                        )),
 850                    )))
 851                }
 852                _ => {
 853                    let mut command = configuration
 854                        .command()
 855                        .context("Missing command configuration for stdio context server")?
 856                        .clone();
 857                    command.timeout = Some(
 858                        command
 859                            .timeout
 860                            .unwrap_or(global_timeout)
 861                            .min(MAX_TIMEOUT_SECS),
 862                    );
 863
 864                    // Don't pass remote paths as working directory for locally-spawned processes
 865                    let working_directory = if is_remote_project { None } else { root_path };
 866                    anyhow::Ok(Arc::new(ContextServer::stdio(
 867                        id,
 868                        command,
 869                        working_directory,
 870                    )))
 871                }
 872            }
 873        })??;
 874
 875        Ok((server, configuration))
 876    }
 877
 878    async fn handle_get_context_server_command(
 879        this: Entity<Self>,
 880        envelope: TypedEnvelope<proto::GetContextServerCommand>,
 881        mut cx: AsyncApp,
 882    ) -> Result<proto::ContextServerCommand> {
 883        let server_id = ContextServerId(envelope.payload.server_id.into());
 884
 885        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
 886            let ContextServerStoreState::Local {
 887                is_headless: true, ..
 888            } = &this.state
 889            else {
 890                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
 891            };
 892
 893            let settings = this
 894                .context_server_settings
 895                .get(&server_id.0)
 896                .cloned()
 897                .or_else(|| {
 898                    this.registry
 899                        .read(inner_cx)
 900                        .context_server_descriptor(&server_id.0)
 901                        .map(|_| ContextServerSettings::default_extension())
 902                })
 903                .with_context(|| format!("context server `{}` not found", server_id))?;
 904
 905            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
 906        })?;
 907
 908        let configuration = ContextServerConfiguration::from_settings(
 909            settings,
 910            server_id.clone(),
 911            registry,
 912            worktree_store,
 913            &cx,
 914        )
 915        .await
 916        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
 917
 918        let command = configuration
 919            .command()
 920            .context("context server has no command (HTTP servers don't need RPC)")?;
 921
 922        Ok(proto::ContextServerCommand {
 923            path: command.path.display().to_string(),
 924            args: command.args.clone(),
 925            env: command
 926                .env
 927                .clone()
 928                .map(|env| env.into_iter().collect())
 929                .unwrap_or_default(),
 930        })
 931    }
 932
 933    fn resolve_project_settings<'a>(
 934        worktree_store: &'a Entity<WorktreeStore>,
 935        cx: &'a App,
 936    ) -> &'a ProjectSettings {
 937        let location = worktree_store
 938            .read(cx)
 939            .visible_worktrees(cx)
 940            .next()
 941            .map(|worktree| settings::SettingsLocation {
 942                worktree_id: worktree.read(cx).id(),
 943                path: RelPath::empty(),
 944            });
 945        ProjectSettings::get(location, cx)
 946    }
 947
 948    fn create_oauth_token_provider(
 949        id: &ContextServerId,
 950        server_url: &url::Url,
 951        session: OAuthSession,
 952        http_client: Arc<dyn HttpClient>,
 953        credentials_provider: Arc<dyn CredentialsProvider>,
 954        cx: &mut AsyncApp,
 955    ) -> Arc<dyn oauth::OAuthTokenProvider> {
 956        let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
 957        let id = id.clone();
 958        let server_url = server_url.clone();
 959
 960        cx.spawn(async move |cx| {
 961            while let Some(refreshed_session) = token_refresh_rx.next().await {
 962                if let Err(err) =
 963                    Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
 964                        .await
 965                {
 966                    log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
 967                }
 968            }
 969            log::debug!("{} OAuth session persistence task ended", id);
 970        })
 971        .detach();
 972
 973        Arc::new(McpOAuthTokenProvider::new(
 974            session,
 975            http_client,
 976            Some(token_refresh_tx),
 977        ))
 978    }
 979
 980    /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
 981    ///
 982    /// This starts a loopback HTTP callback server on an ephemeral port, builds
 983    /// the authorization URL, opens the user's browser, waits for the callback,
 984    /// exchanges the code for tokens, persists them in the keychain, and restarts
 985    /// the server with the new token provider.
 986    pub fn authenticate_server(
 987        &mut self,
 988        id: &ContextServerId,
 989        cx: &mut Context<Self>,
 990    ) -> Result<()> {
 991        let state = self.servers.get(id).context("Context server not found")?;
 992
 993        let (discovery, server, configuration) = match state {
 994            ContextServerState::AuthRequired {
 995                discovery,
 996                server,
 997                configuration,
 998            } => (discovery.clone(), server.clone(), configuration.clone()),
 999            _ => anyhow::bail!("Server is not in AuthRequired state"),
1000        };
1001
1002        let id = id.clone();
1003
1004        let task = cx.spawn({
1005            let id = id.clone();
1006            let server = server.clone();
1007            let configuration = configuration.clone();
1008            async move |this, cx| {
1009                let result = Self::run_oauth_flow(
1010                    this.clone(),
1011                    id.clone(),
1012                    discovery.clone(),
1013                    configuration.clone(),
1014                    cx,
1015                )
1016                .await;
1017
1018                if let Err(err) = &result {
1019                    log::error!("{} OAuth authentication failed: {:?}", id, err);
1020                    // Transition back to AuthRequired so the user can retry
1021                    // rather than landing in a terminal Error state.
1022                    this.update(cx, |this, cx| {
1023                        this.update_server_state(
1024                            id.clone(),
1025                            ContextServerState::AuthRequired {
1026                                server,
1027                                configuration,
1028                                discovery,
1029                            },
1030                            cx,
1031                        )
1032                    })
1033                    .log_err();
1034                }
1035            }
1036        });
1037
1038        self.update_server_state(
1039            id,
1040            ContextServerState::Authenticating {
1041                server,
1042                configuration,
1043                _task: task,
1044            },
1045            cx,
1046        );
1047
1048        Ok(())
1049    }
1050
1051    async fn run_oauth_flow(
1052        this: WeakEntity<Self>,
1053        id: ContextServerId,
1054        discovery: Arc<OAuthDiscovery>,
1055        configuration: Arc<ContextServerConfiguration>,
1056        cx: &mut AsyncApp,
1057    ) -> Result<()> {
1058        let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1059        let pkce = oauth::generate_pkce_challenge();
1060
1061        let mut state_bytes = [0u8; 32];
1062        rand::rng().fill(&mut state_bytes);
1063        let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1064
1065        // Start a loopback HTTP server on an ephemeral port. The redirect URI
1066        // includes this port so the browser sends the callback directly to our
1067        // process.
1068        let (redirect_uri, callback_rx) = oauth::start_callback_server()
1069            .await
1070            .context("Failed to start OAuth callback server")?;
1071
1072        let http_client = cx.update(|cx| cx.http_client());
1073        let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
1074        let server_url = match configuration.as_ref() {
1075            ContextServerConfiguration::Http { url, .. } => url.clone(),
1076            _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1077        };
1078
1079        let client_registration =
1080            oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1081                .await
1082                .context("Failed to resolve OAuth client registration")?;
1083
1084        let auth_url = oauth::build_authorization_url(
1085            &discovery.auth_server_metadata,
1086            &client_registration.client_id,
1087            &redirect_uri,
1088            &discovery.scopes,
1089            &resource,
1090            &pkce,
1091            &state_param,
1092        );
1093
1094        cx.update(|cx| cx.open_url(auth_url.as_str()));
1095
1096        let callback = callback_rx
1097            .await
1098            .map_err(|_| {
1099                anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1100            })?
1101            .context("OAuth callback server received an invalid request")?;
1102
1103        if callback.state != state_param {
1104            anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1105        }
1106
1107        let tokens = oauth::exchange_code(
1108            &http_client,
1109            &discovery.auth_server_metadata,
1110            &callback.code,
1111            &client_registration.client_id,
1112            &redirect_uri,
1113            &pkce.verifier,
1114            &resource,
1115        )
1116        .await
1117        .context("Failed to exchange authorization code for tokens")?;
1118
1119        let session = OAuthSession {
1120            token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1121            resource: discovery.resource_metadata.resource.clone(),
1122            client_registration,
1123            tokens,
1124        };
1125
1126        Self::store_session(&credentials_provider, &server_url, &session, cx)
1127            .await
1128            .context("Failed to persist OAuth session in keychain")?;
1129
1130        let token_provider = Self::create_oauth_token_provider(
1131            &id,
1132            &server_url,
1133            session,
1134            http_client.clone(),
1135            credentials_provider,
1136            cx,
1137        );
1138
1139        let new_server = this.update(cx, |this, cx| {
1140            let global_timeout =
1141                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1142
1143            match configuration.as_ref() {
1144                ContextServerConfiguration::Http {
1145                    url,
1146                    headers,
1147                    timeout,
1148                } => {
1149                    let transport = HttpTransport::new_with_token_provider(
1150                        http_client.clone(),
1151                        url.to_string(),
1152                        headers.clone(),
1153                        cx.background_executor().clone(),
1154                        Some(token_provider.clone()),
1155                    );
1156                    Ok(Arc::new(ContextServer::new_with_timeout(
1157                        id.clone(),
1158                        Arc::new(transport),
1159                        Some(Duration::from_secs(
1160                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1161                        )),
1162                    )))
1163                }
1164                _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1165            }
1166        })??;
1167
1168        this.update(cx, |this, cx| {
1169            this.run_server(new_server, configuration, cx);
1170        })?;
1171
1172        Ok(())
1173    }
1174
1175    /// Store the full OAuth session in the system keychain, keyed by the
1176    /// server's canonical URI.
1177    async fn store_session(
1178        credentials_provider: &Arc<dyn CredentialsProvider>,
1179        server_url: &url::Url,
1180        session: &OAuthSession,
1181        cx: &AsyncApp,
1182    ) -> Result<()> {
1183        let key = Self::keychain_key(server_url);
1184        let json = serde_json::to_string(session)?;
1185        credentials_provider
1186            .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1187            .await
1188    }
1189
1190    /// Load the full OAuth session from the system keychain for the given
1191    /// server URL.
1192    async fn load_session(
1193        credentials_provider: &Arc<dyn CredentialsProvider>,
1194        server_url: &url::Url,
1195        cx: &AsyncApp,
1196    ) -> Result<Option<OAuthSession>> {
1197        let key = Self::keychain_key(server_url);
1198        match credentials_provider.read_credentials(&key, cx).await? {
1199            Some((_username, password_bytes)) => {
1200                let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1201                Ok(Some(session))
1202            }
1203            None => Ok(None),
1204        }
1205    }
1206
1207    /// Clear the stored OAuth session from the system keychain.
1208    async fn clear_session(
1209        credentials_provider: &Arc<dyn CredentialsProvider>,
1210        server_url: &url::Url,
1211        cx: &AsyncApp,
1212    ) -> Result<()> {
1213        let key = Self::keychain_key(server_url);
1214        credentials_provider.delete_credentials(&key, cx).await
1215    }
1216
1217    fn keychain_key(server_url: &url::Url) -> String {
1218        format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1219    }
1220
1221    /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1222    /// session from the keychain and stop the server.
1223    pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1224        let state = self.servers.get(id).context("Context server not found")?;
1225        let configuration = state.configuration();
1226
1227        let server_url = match configuration.as_ref() {
1228            ContextServerConfiguration::Http { url, .. } => url.clone(),
1229            _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1230        };
1231
1232        let id = id.clone();
1233        self.stop_server(&id, cx)?;
1234
1235        cx.spawn(async move |this, cx| {
1236            let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
1237            if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1238                log::error!("{} failed to clear OAuth session: {}", id, err);
1239            }
1240            // Trigger server recreation so the next start uses a fresh
1241            // transport without the old (now-invalidated) token provider.
1242            this.update(cx, |this, cx| {
1243                this.available_context_servers_changed(cx);
1244            })
1245            .log_err();
1246        })
1247        .detach();
1248
1249        Ok(())
1250    }
1251
1252    fn update_server_state(
1253        &mut self,
1254        id: ContextServerId,
1255        state: ContextServerState,
1256        cx: &mut Context<Self>,
1257    ) {
1258        let status = ContextServerStatus::from_state(&state);
1259        self.servers.insert(id.clone(), state);
1260        cx.emit(ServerStatusChangedEvent {
1261            server_id: id,
1262            status,
1263        });
1264    }
1265
1266    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
1267        if self.update_servers_task.is_some() {
1268            self.needs_server_update = true;
1269        } else {
1270            self.needs_server_update = false;
1271            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
1272                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
1273                    log::error!("Error maintaining context servers: {}", err);
1274                }
1275
1276                this.update(cx, |this, cx| {
1277                    this.populate_server_ids(cx);
1278                    cx.notify();
1279                    this.update_servers_task.take();
1280                    if this.needs_server_update {
1281                        this.available_context_servers_changed(cx);
1282                    }
1283                })?;
1284
1285                Ok(())
1286            }));
1287        }
1288    }
1289
    /// Reconciles the store's servers with the current settings and the
    /// extension registry.
    ///
    /// When AI is disabled (`DisableAiSettings::disable_ai`), every server in
    /// the store is stopped (stop errors are ignored) and nothing is started.
    ///
    /// Otherwise, the desired servers are `context_server_settings` plus every
    /// registry descriptor. Descriptors without explicit settings use
    /// `ContextServerSettings::default_extension()`. Enabled entries have
    /// their configuration resolved concurrently, and entries whose
    /// configuration does not resolve are dropped. Then:
    /// - store entries not among the resolved enabled servers are stopped if
    ///   explicitly disabled, and removed from the store otherwise;
    /// - enabled servers that are new, currently `Stopped`, or whose
    ///   configuration changed are (re)created and started; any existing
    ///   instance is stopped first.
    ///
    /// A failure to create an individual server does not abort the pass. It
    /// is logged and reported via a `ServerStatusChangedEvent` with an `Error`
    /// status. Errors from `stop_server`/`remove_server`, or a released store
    /// entity, abort the pass with an error.
    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        // Don't start context servers if AI is disabled
        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
        if ai_disabled {
            // Stop all running servers when AI is disabled
            this.update(cx, |this, cx| {
                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
                for id in server_ids {
                    let _ = this.stop_server(&id, cx);
                }
            })?;
            return Ok(());
        }

        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
            (
                this.context_server_settings.clone(),
                this.registry.clone(),
                this.worktree_store.clone(),
            )
        })?;

        // Extension-provided servers without explicit settings still get a
        // default settings entry so they take part in reconciliation.
        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
            configured_servers
                .entry(id)
                .or_insert(ContextServerSettings::default_extension());
        }

        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
            configured_servers
                .into_iter()
                .partition(|(_, settings)| settings.enabled());

        // Resolve all enabled configurations concurrently. Servers whose
        // configuration resolves to `None` are dropped here, and so are removed
        // from the store below if they are present there.
        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
            let id = ContextServerId(id);
            ContextServerConfiguration::from_settings(
                settings,
                id.clone(),
                registry.clone(),
                worktree_store.clone(),
                cx,
            )
            .map(move |config| (id, config))
        }))
        .await
        .into_iter()
        .filter_map(|(id, config)| config.map(|config| (id, config)))
        .collect::<HashMap<_, _>>();

        let mut servers_to_start = Vec::new();
        let mut servers_to_remove = HashSet::default();
        let mut servers_to_stop = HashSet::default();

        this.update(cx, |this, _cx| {
            for server_id in this.servers.keys() {
                // Servers in the store but not among the resolved
                // `configured_servers` must not keep running: explicitly
                // disabled ones are only stopped (they stay in the store);
                // anything else (e.g. removed from the settings, or whose
                // configuration failed to resolve) is removed from the store.
                if !configured_servers.contains_key(server_id) {
                    if disabled_servers.contains_key(&server_id.0) {
                        servers_to_stop.insert(server_id.clone());
                    } else {
                        servers_to_remove.insert(server_id.clone());
                    }
                }
            }

            for (id, config) in configured_servers {
                let state = this.servers.get(&id);
                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
                let existing_config = state.as_ref().map(|state| state.configuration());
                // (Re)start when the server is new, its configuration changed,
                // or it is currently stopped. An existing instance is stopped
                // first.
                if existing_config.as_deref() != Some(&config) || is_stopped {
                    let config = Arc::new(config);
                    servers_to_start.push((id.clone(), config));
                    if this.servers.contains_key(&id) {
                        servers_to_stop.insert(id);
                    }
                }
            }

            anyhow::Ok(())
        })??;

        this.update(cx, |this, inner_cx| {
            for id in servers_to_stop {
                this.stop_server(&id, inner_cx)?;
            }
            for id in servers_to_remove {
                this.remove_server(&id, inner_cx)?;
            }
            anyhow::Ok(())
        })??;

        // Create and start the servers one at a time. A creation failure only
        // emits an `Error` status event (the stored state is not updated here)
        // and does not prevent the remaining servers from starting.
        for (id, config) in servers_to_start {
            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
                Ok((server, config)) => {
                    this.update(cx, |this, cx| {
                        this.run_server(server, config, cx);
                    })?;
                }
                Err(err) => {
                    log::error!("{id} context server failed to create: {err:#}");
                    this.update(cx, |_this, cx| {
                        cx.emit(ServerStatusChangedEvent {
                            server_id: id,
                            status: ContextServerStatus::Error(err.to_string().into()),
                        });
                        cx.notify();
                    })?;
                }
            }
        }

        Ok(())
    }
1404}
1405
/// Determines the appropriate server state after a start attempt fails.
///
/// Outcomes:
/// - A `TransportError::AuthRequired` (HTTP 401) while a static
///   Authorization header is configured yields an `Error` that points the
///   user at that header.
/// - Non-HTTP configurations, and HTTP configurations with a static auth
///   header, yield an `Error` carrying `err`'s message.
/// - For an HTTP server without a static auth header:
///   - on a 401, OAuth discovery runs with the server's `WWW-Authenticate`
///     challenge;
///   - on any other error, discovery runs only if a cached OAuth session was
///     found in the keychain, and that session is cleared first;
///   - otherwise an `Error` is returned.
/// - Successful discovery yields `AuthRequired`, so the UI can offer an
///   authentication flow. A discovery failure yields an `Error`.
async fn resolve_start_failure(
    id: &ContextServerId,
    err: anyhow::Error,
    server: Arc<ContextServer>,
    configuration: Arc<ContextServerConfiguration>,
    cx: &AsyncApp,
) -> ContextServerState {
    // `Some(challenge)` only when the failure was a transport-level 401.
    let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
        TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
    });

    if www_authenticate.is_some() && configuration.has_static_auth_header() {
        log::warn!("{id} received 401 with a static Authorization header configured");
        return ContextServerState::Error {
            configuration,
            server,
            error: "Server returned 401 Unauthorized. Check your configured Authorization header."
                .into(),
        };
    }

    let server_url = match configuration.as_ref() {
        ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
            url.clone()
        }
        // Reached for non-HTTP configurations, or HTTP with a static auth
        // header and a non-401 error (401 + static header returned above).
        _ => {
            if www_authenticate.is_some() {
                log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
            } else {
                log::error!("{id} context server failed to start: {err}");
            }
            return ContextServerState::Error {
                configuration,
                server,
                error: err.to_string().into(),
            };
        }
    };

    // When the error is NOT a 401 but there is a cached OAuth session in the
    // keychain, the session is presumably stale/expired and caused the failure
    // (e.g. a timeout because the server rejected the token silently). Clear it
    // and fall through to OAuth discovery below (with an empty challenge), so
    // the user is offered a fresh authentication flow right away. Without a
    // cached session, or if reading the keychain fails, the original error is
    // reported as-is.
    if www_authenticate.is_none() {
        let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
        match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
            Ok(Some(_)) => {
                log::info!("{id} start failed with a cached OAuth session present; clearing it");
                ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
                    .await
                    .log_err();
            }
            _ => {
                log::error!("{id} context server failed to start: {err}");
                return ContextServerState::Error {
                    configuration,
                    server,
                    error: err.to_string().into(),
                };
            }
        }
    }

    // Only the cleared-stale-session path gets here without a challenge; in
    // that case discovery starts from an empty `WWW-Authenticate`.
    let default_www_authenticate = oauth::WwwAuthenticate {
        resource_metadata: None,
        scope: None,
        error: None,
        error_description: None,
    };
    let www_authenticate = www_authenticate
        .as_ref()
        .unwrap_or(&default_www_authenticate);
    let http_client = cx.update(|cx| cx.http_client());

    match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
        Ok(discovery) => {
            log::info!(
                "{id} requires OAuth authorization (auth server: {})",
                discovery.auth_server_metadata.issuer,
            );
            ContextServerState::AuthRequired {
                server,
                configuration,
                discovery: Arc::new(discovery),
            }
        }
        Err(discovery_err) => {
            log::error!("{id} OAuth discovery failed: {discovery_err}");
            ContextServerState::Error {
                configuration,
                server,
                error: format!("OAuth discovery failed: {discovery_err}").into(),
            }
        }
    }
}