context_server_store.rs

   1pub mod extension;
   2pub mod registry;
   3
   4use std::path::Path;
   5use std::sync::Arc;
   6use std::time::Duration;
   7
   8use anyhow::{Context as _, Result};
   9use collections::{HashMap, HashSet};
  10use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
  11use context_server::transport::{HttpTransport, TransportError};
  12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  13use credentials_provider::CredentialsProvider;
  14use futures::future::Either;
  15use futures::{FutureExt as _, StreamExt as _, future::join_all};
  16use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
  17use http_client::HttpClient;
  18use itertools::Itertools;
  19use rand::Rng as _;
  20use registry::ContextServerDescriptorRegistry;
  21use remote::RemoteClient;
  22use rpc::{AnyProtoClient, TypedEnvelope, proto};
  23use settings::{Settings as _, SettingsStore};
  24use util::{ResultExt as _, rel_path::RelPath};
  25
  26use crate::{
  27    DisableAiSettings, Project,
  28    project_settings::{ContextServerSettings, OAuthClientSettings, ProjectSettings},
  29    worktree_store::WorktreeStore,
  30};
  31
  32/// Maximum timeout for context server requests
  33/// Prevents extremely large timeout values from tying up resources indefinitely.
  34const MAX_TIMEOUT_SECS: u64 = 600; // 10 minutes
  35
  36pub fn init(cx: &mut App) {
  37    extension::init(cx);
  38}
  39
// Actions exposed under the `context_server` namespace. The `///` text on each
// action is passed into the macro, so it is left exactly as written.
actions!(
    context_server,
    [
        /// Restarts the context server.
        Restart
    ]
);
  47
  48#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  49pub enum ContextServerStatus {
  50    Starting,
  51    Running,
  52    Stopped,
  53    Error(Arc<str>),
  54    /// The server returned 401 and OAuth authorization is needed. The UI
  55    /// should show an "Authenticate" button.
  56    AuthRequired,
  57    /// The server has a pre-registered OAuth client_id, but a client_secret
  58    /// is needed and not available in settings or the keychain.
  59    ClientSecretRequired {
  60        error: Option<Arc<str>>,
  61    },
  62    /// The OAuth browser flow is in progress — the user has been redirected
  63    /// to the authorization server and we're waiting for the callback.
  64    Authenticating,
  65}
  66
  67impl ContextServerStatus {
  68    fn from_state(state: &ContextServerState) -> Self {
  69        match state {
  70            ContextServerState::Starting { .. } => ContextServerStatus::Starting,
  71            ContextServerState::Running { .. } => ContextServerStatus::Running,
  72            ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
  73            ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
  74            ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
  75            ContextServerState::ClientSecretRequired { error, .. } => {
  76                ContextServerStatus::ClientSecretRequired {
  77                    error: error.clone(),
  78                }
  79            }
  80            ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
  81        }
  82    }
  83}
  84
  85enum ContextServerState {
  86    Starting {
  87        server: Arc<ContextServer>,
  88        configuration: Arc<ContextServerConfiguration>,
  89        _task: Task<()>,
  90    },
  91    Running {
  92        server: Arc<ContextServer>,
  93        configuration: Arc<ContextServerConfiguration>,
  94    },
  95    Stopped {
  96        server: Arc<ContextServer>,
  97        configuration: Arc<ContextServerConfiguration>,
  98    },
  99    Error {
 100        server: Arc<ContextServer>,
 101        configuration: Arc<ContextServerConfiguration>,
 102        error: Arc<str>,
 103    },
 104    /// The server requires OAuth authorization before it can be used. The
 105    /// `OAuthDiscovery` holds everything needed to start the browser flow.
 106    AuthRequired {
 107        server: Arc<ContextServer>,
 108        configuration: Arc<ContextServerConfiguration>,
 109        discovery: Arc<OAuthDiscovery>,
 110    },
 111    /// A pre-registered client_id is configured but no client_secret was found
 112    /// in settings or the keychain.
 113    ClientSecretRequired {
 114        server: Arc<ContextServer>,
 115        configuration: Arc<ContextServerConfiguration>,
 116        discovery: Arc<OAuthDiscovery>,
 117        error: Option<Arc<str>>,
 118    },
 119    /// The OAuth browser flow is in progress. The user has been redirected
 120    /// to the authorization server and we're waiting for the callback.
 121    Authenticating {
 122        server: Arc<ContextServer>,
 123        configuration: Arc<ContextServerConfiguration>,
 124        _task: Task<()>,
 125    },
 126}
 127
 128impl ContextServerState {
 129    pub fn server(&self) -> Arc<ContextServer> {
 130        match self {
 131            ContextServerState::Starting { server, .. }
 132            | ContextServerState::Running { server, .. }
 133            | ContextServerState::Stopped { server, .. }
 134            | ContextServerState::Error { server, .. }
 135            | ContextServerState::AuthRequired { server, .. }
 136            | ContextServerState::ClientSecretRequired { server, .. }
 137            | ContextServerState::Authenticating { server, .. } => server.clone(),
 138        }
 139    }
 140
 141    pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
 142        match self {
 143            ContextServerState::Starting { configuration, .. }
 144            | ContextServerState::Running { configuration, .. }
 145            | ContextServerState::Stopped { configuration, .. }
 146            | ContextServerState::Error { configuration, .. }
 147            | ContextServerState::AuthRequired { configuration, .. }
 148            | ContextServerState::ClientSecretRequired { configuration, .. }
 149            | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
 150        }
 151    }
 152}
 153
 154#[derive(Debug, PartialEq, Eq)]
 155pub enum ContextServerConfiguration {
 156    Custom {
 157        command: ContextServerCommand,
 158        remote: bool,
 159    },
 160    Extension {
 161        command: ContextServerCommand,
 162        settings: serde_json::Value,
 163        remote: bool,
 164    },
 165    Http {
 166        url: url::Url,
 167        headers: HashMap<String, String>,
 168        timeout: Option<u64>,
 169        oauth: Option<OAuthClientSettings>,
 170    },
 171}
 172
 173impl ContextServerConfiguration {
 174    pub fn command(&self) -> Option<&ContextServerCommand> {
 175        match self {
 176            ContextServerConfiguration::Custom { command, .. } => Some(command),
 177            ContextServerConfiguration::Extension { command, .. } => Some(command),
 178            ContextServerConfiguration::Http { .. } => None,
 179        }
 180    }
 181
 182    pub fn has_static_auth_header(&self) -> bool {
 183        match self {
 184            ContextServerConfiguration::Http { headers, .. } => headers
 185                .keys()
 186                .any(|k| k.eq_ignore_ascii_case("authorization")),
 187            _ => false,
 188        }
 189    }
 190
 191    pub fn remote(&self) -> bool {
 192        match self {
 193            ContextServerConfiguration::Custom { remote, .. } => *remote,
 194            ContextServerConfiguration::Extension { remote, .. } => *remote,
 195            ContextServerConfiguration::Http { .. } => false,
 196        }
 197    }
 198
 199    pub async fn from_settings(
 200        settings: ContextServerSettings,
 201        id: ContextServerId,
 202        registry: Entity<ContextServerDescriptorRegistry>,
 203        worktree_store: Entity<WorktreeStore>,
 204        cx: &AsyncApp,
 205    ) -> Option<Self> {
 206        const EXTENSION_COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
 207
 208        match settings {
 209            ContextServerSettings::Stdio {
 210                enabled: _,
 211                command,
 212                remote,
 213            } => Some(ContextServerConfiguration::Custom { command, remote }),
 214            ContextServerSettings::Extension {
 215                enabled: _,
 216                settings,
 217                remote,
 218            } => {
 219                let descriptor =
 220                    cx.update(|cx| registry.read(cx).context_server_descriptor(&id.0))?;
 221
 222                let command_future = descriptor.command(worktree_store, cx);
 223                let timeout_future = cx.background_executor().timer(EXTENSION_COMMAND_TIMEOUT);
 224
 225                match futures::future::select(command_future, timeout_future).await {
 226                    Either::Left((Ok(command), _)) => Some(ContextServerConfiguration::Extension {
 227                        command,
 228                        settings,
 229                        remote,
 230                    }),
 231                    Either::Left((Err(e), _)) => {
 232                        log::error!(
 233                            "Failed to create context server configuration from settings: {e:#}"
 234                        );
 235                        None
 236                    }
 237                    Either::Right(_) => {
 238                        log::error!(
 239                            "Timed out resolving command for extension context server {id}"
 240                        );
 241                        None
 242                    }
 243                }
 244            }
 245            ContextServerSettings::Http {
 246                enabled: _,
 247                url,
 248                headers: auth,
 249                timeout,
 250                oauth,
 251            } => {
 252                let url = url::Url::parse(&url).log_err()?;
 253                Some(ContextServerConfiguration::Http {
 254                    url,
 255                    headers: auth,
 256                    timeout,
 257                    oauth,
 258                })
 259            }
 260        }
 261    }
 262}
 263
 264pub type ContextServerFactory =
 265    Box<dyn Fn(ContextServerId, Arc<ContextServerConfiguration>) -> Arc<ContextServer>>;
 266
 267enum ContextServerStoreState {
 268    Local {
 269        downstream_client: Option<(u64, AnyProtoClient)>,
 270        is_headless: bool,
 271    },
 272    Remote {
 273        project_id: u64,
 274        upstream_client: Entity<RemoteClient>,
 275    },
 276}
 277
 278pub struct ContextServerStore {
 279    state: ContextServerStoreState,
 280    context_server_settings: HashMap<Arc<str>, ContextServerSettings>,
 281    servers: HashMap<ContextServerId, ContextServerState>,
 282    server_ids: Vec<ContextServerId>,
 283    worktree_store: Entity<WorktreeStore>,
 284    project: Option<WeakEntity<Project>>,
 285    registry: Entity<ContextServerDescriptorRegistry>,
 286    update_servers_task: Option<Task<Result<()>>>,
 287    context_server_factory: Option<ContextServerFactory>,
 288    needs_server_update: bool,
 289    ai_disabled: bool,
 290    _subscriptions: Vec<Subscription>,
 291}
 292
 293pub struct ServerStatusChangedEvent {
 294    pub server_id: ContextServerId,
 295    pub status: ContextServerStatus,
 296}
 297
// Lets the store emit `ServerStatusChangedEvent` via `cx.emit`, and lets
// observers subscribe to it.
impl EventEmitter<ServerStatusChangedEvent> for ContextServerStore {}
 299
 300impl ContextServerStore {
 301    pub fn local(
 302        worktree_store: Entity<WorktreeStore>,
 303        weak_project: Option<WeakEntity<Project>>,
 304        headless: bool,
 305        cx: &mut Context<Self>,
 306    ) -> Self {
 307        Self::new_internal(
 308            !headless,
 309            None,
 310            ContextServerDescriptorRegistry::default_global(cx),
 311            worktree_store,
 312            weak_project,
 313            ContextServerStoreState::Local {
 314                downstream_client: None,
 315                is_headless: headless,
 316            },
 317            cx,
 318        )
 319    }
 320
 321    pub fn remote(
 322        project_id: u64,
 323        upstream_client: Entity<RemoteClient>,
 324        worktree_store: Entity<WorktreeStore>,
 325        weak_project: Option<WeakEntity<Project>>,
 326        cx: &mut Context<Self>,
 327    ) -> Self {
 328        Self::new_internal(
 329            true,
 330            None,
 331            ContextServerDescriptorRegistry::default_global(cx),
 332            worktree_store,
 333            weak_project,
 334            ContextServerStoreState::Remote {
 335                project_id,
 336                upstream_client,
 337            },
 338            cx,
 339        )
 340    }
 341
 342    pub fn init_headless(session: &AnyProtoClient) {
 343        session.add_entity_request_handler(Self::handle_get_context_server_command);
 344    }
 345
 346    pub fn shared(&mut self, project_id: u64, client: AnyProtoClient) {
 347        if let ContextServerStoreState::Local {
 348            downstream_client, ..
 349        } = &mut self.state
 350        {
 351            *downstream_client = Some((project_id, client));
 352        }
 353    }
 354
 355    pub fn is_remote_project(&self) -> bool {
 356        matches!(self.state, ContextServerStoreState::Remote { .. })
 357    }
 358
 359    /// Returns all configured context server ids, excluding the ones that are disabled
 360    pub fn configured_server_ids(&self) -> Vec<ContextServerId> {
 361        self.context_server_settings
 362            .iter()
 363            .filter(|(_, settings)| settings.enabled())
 364            .map(|(id, _)| ContextServerId(id.clone()))
 365            .collect()
 366    }
 367
 368    #[cfg(feature = "test-support")]
 369    pub fn test(
 370        registry: Entity<ContextServerDescriptorRegistry>,
 371        worktree_store: Entity<WorktreeStore>,
 372        weak_project: Option<WeakEntity<Project>>,
 373        cx: &mut Context<Self>,
 374    ) -> Self {
 375        Self::new_internal(
 376            false,
 377            None,
 378            registry,
 379            worktree_store,
 380            weak_project,
 381            ContextServerStoreState::Local {
 382                downstream_client: None,
 383                is_headless: false,
 384            },
 385            cx,
 386        )
 387    }
 388
 389    #[cfg(feature = "test-support")]
 390    pub fn test_maintain_server_loop(
 391        context_server_factory: Option<ContextServerFactory>,
 392        registry: Entity<ContextServerDescriptorRegistry>,
 393        worktree_store: Entity<WorktreeStore>,
 394        weak_project: Option<WeakEntity<Project>>,
 395        cx: &mut Context<Self>,
 396    ) -> Self {
 397        Self::new_internal(
 398            true,
 399            context_server_factory,
 400            registry,
 401            worktree_store,
 402            weak_project,
 403            ContextServerStoreState::Local {
 404                downstream_client: None,
 405                is_headless: false,
 406            },
 407            cx,
 408        )
 409    }
 410
 411    #[cfg(feature = "test-support")]
 412    pub fn set_context_server_factory(&mut self, factory: ContextServerFactory) {
 413        self.context_server_factory = Some(factory);
 414    }
 415
 416    #[cfg(feature = "test-support")]
 417    pub fn registry(&self) -> &Entity<ContextServerDescriptorRegistry> {
 418        &self.registry
 419    }
 420
 421    #[cfg(feature = "test-support")]
 422    pub fn test_start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 423        let configuration = Arc::new(ContextServerConfiguration::Custom {
 424            command: ContextServerCommand {
 425                path: "test".into(),
 426                args: vec![],
 427                env: None,
 428                timeout: None,
 429            },
 430            remote: false,
 431        });
 432        self.run_server(server, configuration, cx);
 433    }
 434
 435    fn new_internal(
 436        maintain_server_loop: bool,
 437        context_server_factory: Option<ContextServerFactory>,
 438        registry: Entity<ContextServerDescriptorRegistry>,
 439        worktree_store: Entity<WorktreeStore>,
 440        weak_project: Option<WeakEntity<Project>>,
 441        state: ContextServerStoreState,
 442        cx: &mut Context<Self>,
 443    ) -> Self {
 444        let mut subscriptions = vec![cx.observe_global::<SettingsStore>(move |this, cx| {
 445            let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 446            let ai_was_disabled = this.ai_disabled;
 447            this.ai_disabled = ai_disabled;
 448
 449            let settings =
 450                &Self::resolve_project_settings(&this.worktree_store, cx).context_servers;
 451            let settings_changed = &this.context_server_settings != settings;
 452
 453            if settings_changed {
 454                this.context_server_settings = settings.clone();
 455            }
 456
 457            // When AI is disabled, stop all running servers
 458            if ai_disabled {
 459                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
 460                for id in server_ids {
 461                    this.stop_server(&id, cx).log_err();
 462                }
 463                return;
 464            }
 465
 466            // Trigger updates if AI was re-enabled or settings changed
 467            if maintain_server_loop && (ai_was_disabled || settings_changed) {
 468                this.available_context_servers_changed(cx);
 469            }
 470        })];
 471
 472        if maintain_server_loop {
 473            subscriptions.push(cx.observe(&registry, |this, _registry, cx| {
 474                if !DisableAiSettings::get_global(cx).disable_ai {
 475                    this.available_context_servers_changed(cx);
 476                }
 477            }));
 478        }
 479
 480        let ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
 481        let mut this = Self {
 482            state,
 483            _subscriptions: subscriptions,
 484            context_server_settings: Self::resolve_project_settings(&worktree_store, cx)
 485                .context_servers
 486                .clone(),
 487            worktree_store,
 488            project: weak_project,
 489            registry,
 490            needs_server_update: false,
 491            ai_disabled,
 492            servers: HashMap::default(),
 493            server_ids: Default::default(),
 494            update_servers_task: None,
 495            context_server_factory,
 496        };
 497        if maintain_server_loop && !DisableAiSettings::get_global(cx).disable_ai {
 498            this.available_context_servers_changed(cx);
 499        }
 500        this
 501    }
 502
 503    pub fn get_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 504        self.servers.get(id).map(|state| state.server())
 505    }
 506
 507    pub fn get_running_server(&self, id: &ContextServerId) -> Option<Arc<ContextServer>> {
 508        if let Some(ContextServerState::Running { server, .. }) = self.servers.get(id) {
 509            Some(server.clone())
 510        } else {
 511            None
 512        }
 513    }
 514
 515    pub fn status_for_server(&self, id: &ContextServerId) -> Option<ContextServerStatus> {
 516        self.servers.get(id).map(ContextServerStatus::from_state)
 517    }
 518
 519    pub fn configuration_for_server(
 520        &self,
 521        id: &ContextServerId,
 522    ) -> Option<Arc<ContextServerConfiguration>> {
 523        self.servers.get(id).map(|state| state.configuration())
 524    }
 525
 526    /// Returns a sorted slice of available unique context server IDs. Within the
 527    /// slice, context servers which have `mcp-server-` as a prefix in their ID will
 528    /// appear after servers that do not have this prefix in their ID.
 529    pub fn server_ids(&self) -> &[ContextServerId] {
 530        self.server_ids.as_slice()
 531    }
 532
 533    fn populate_server_ids(&mut self, cx: &App) {
 534        self.server_ids = self
 535            .servers
 536            .keys()
 537            .cloned()
 538            .chain(
 539                self.registry
 540                    .read(cx)
 541                    .context_server_descriptors()
 542                    .into_iter()
 543                    .map(|(id, _)| ContextServerId(id)),
 544            )
 545            .chain(
 546                self.context_server_settings
 547                    .keys()
 548                    .map(|id| ContextServerId(id.clone())),
 549            )
 550            .unique()
 551            .sorted_unstable_by(
 552                // Sort context servers: ones without mcp-server- prefix first, then prefixed ones
 553                |a, b| {
 554                    const MCP_PREFIX: &str = "mcp-server-";
 555                    match (a.0.strip_prefix(MCP_PREFIX), b.0.strip_prefix(MCP_PREFIX)) {
 556                        // If one has mcp-server- prefix and other doesn't, non-mcp comes first
 557                        (Some(_), None) => std::cmp::Ordering::Greater,
 558                        (None, Some(_)) => std::cmp::Ordering::Less,
 559                        // If both have same prefix status, sort by appropriate key
 560                        (Some(a), Some(b)) => a.cmp(b),
 561                        (None, None) => a.0.cmp(&b.0),
 562                    }
 563                },
 564            )
 565            .collect();
 566    }
 567
 568    pub fn running_servers(&self) -> Vec<Arc<ContextServer>> {
 569        self.servers
 570            .values()
 571            .filter_map(|state| {
 572                if let ContextServerState::Running { server, .. } = state {
 573                    Some(server.clone())
 574                } else {
 575                    None
 576                }
 577            })
 578            .collect()
 579    }
 580
 581    pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
 582        cx.spawn(async move |this, cx| {
 583            let this = this.upgrade().context("Context server store dropped")?;
 584            let id = server.id();
 585            let settings = this
 586                .update(cx, |this, _| {
 587                    this.context_server_settings.get(&id.0).cloned()
 588                })
 589                .context("Failed to get context server settings")?;
 590
 591            if !settings.enabled() {
 592                return anyhow::Ok(());
 593            }
 594
 595            let (registry, worktree_store) = this.update(cx, |this, _| {
 596                (this.registry.clone(), this.worktree_store.clone())
 597            });
 598            let configuration = ContextServerConfiguration::from_settings(
 599                settings,
 600                id.clone(),
 601                registry,
 602                worktree_store,
 603                cx,
 604            )
 605            .await
 606            .context("Failed to create context server configuration")?;
 607
 608            this.update(cx, |this, cx| {
 609                this.run_server(server, Arc::new(configuration), cx)
 610            });
 611            Ok(())
 612        })
 613        .detach_and_log_err(cx);
 614    }
 615
 616    pub fn stop_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 617        if matches!(
 618            self.servers.get(id),
 619            Some(ContextServerState::Stopped { .. })
 620        ) {
 621            return Ok(());
 622        }
 623
 624        let state = self
 625            .servers
 626            .remove(id)
 627            .context("Context server not found")?;
 628
 629        let server = state.server();
 630        let configuration = state.configuration();
 631        let mut result = Ok(());
 632        if let ContextServerState::Running { server, .. } = &state {
 633            result = server.stop();
 634        }
 635        drop(state);
 636
 637        self.update_server_state(
 638            id.clone(),
 639            ContextServerState::Stopped {
 640                configuration,
 641                server,
 642            },
 643            cx,
 644        );
 645
 646        result
 647    }
 648
 649    fn run_server(
 650        &mut self,
 651        server: Arc<ContextServer>,
 652        configuration: Arc<ContextServerConfiguration>,
 653        cx: &mut Context<Self>,
 654    ) {
 655        let id = server.id();
 656        if matches!(
 657            self.servers.get(&id),
 658            Some(
 659                ContextServerState::Starting { .. }
 660                    | ContextServerState::Running { .. }
 661                    | ContextServerState::Authenticating { .. },
 662            )
 663        ) {
 664            self.stop_server(&id, cx).log_err();
 665        }
 666        let task = cx.spawn({
 667            let id = server.id();
 668            let server = server.clone();
 669            let configuration = configuration.clone();
 670
 671            async move |this, cx| {
 672                let new_state = match server.clone().start(cx).await {
 673                    Ok(_) => {
 674                        debug_assert!(server.client().is_some());
 675                        ContextServerState::Running {
 676                            server,
 677                            configuration,
 678                        }
 679                    }
 680                    Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
 681                };
 682                this.update(cx, |this, cx| {
 683                    this.update_server_state(id.clone(), new_state, cx)
 684                })
 685                .log_err();
 686            }
 687        });
 688
 689        self.update_server_state(
 690            id.clone(),
 691            ContextServerState::Starting {
 692                configuration,
 693                _task: task,
 694                server,
 695            },
 696            cx,
 697        );
 698    }
 699
 700    fn remove_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
 701        let state = self
 702            .servers
 703            .remove(id)
 704            .context("Context server not found")?;
 705
 706        if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
 707            let server_url = url.clone();
 708            let id = id.clone();
 709            cx.spawn(async move |_this, cx| {
 710                let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 711                if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
 712                {
 713                    log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
 714                }
 715            })
 716            .detach();
 717        }
 718
 719        drop(state);
 720        cx.emit(ServerStatusChangedEvent {
 721            server_id: id.clone(),
 722            status: ContextServerStatus::Stopped,
 723        });
 724        Ok(())
 725    }
 726
 727    pub async fn create_context_server(
 728        this: WeakEntity<Self>,
 729        id: ContextServerId,
 730        configuration: Arc<ContextServerConfiguration>,
 731        cx: &mut AsyncApp,
 732    ) -> Result<(Arc<ContextServer>, Arc<ContextServerConfiguration>)> {
 733        let remote = configuration.remote();
 734        let needs_remote_command = match configuration.as_ref() {
 735            ContextServerConfiguration::Custom { .. }
 736            | ContextServerConfiguration::Extension { .. } => remote,
 737            ContextServerConfiguration::Http { .. } => false,
 738        };
 739
 740        let (remote_state, is_remote_project) = this.update(cx, |this, _| {
 741            let remote_state = match &this.state {
 742                ContextServerStoreState::Remote {
 743                    project_id,
 744                    upstream_client,
 745                } if needs_remote_command => Some((*project_id, upstream_client.clone())),
 746                _ => None,
 747            };
 748            (remote_state, this.is_remote_project())
 749        })?;
 750
 751        let root_path: Option<Arc<Path>> = this.update(cx, |this, cx| {
 752            this.project
 753                .as_ref()
 754                .and_then(|project| {
 755                    project
 756                        .read_with(cx, |project, cx| project.active_project_directory(cx))
 757                        .ok()
 758                        .flatten()
 759                })
 760                .or_else(|| {
 761                    this.worktree_store.read_with(cx, |store, cx| {
 762                        store.visible_worktrees(cx).fold(None, |acc, item| {
 763                            if acc.is_none() {
 764                                item.read(cx).root_dir()
 765                            } else {
 766                                acc
 767                            }
 768                        })
 769                    })
 770                })
 771        })?;
 772
 773        let configuration = if let Some((project_id, upstream_client)) = remote_state {
 774            let root_dir = root_path.as_ref().map(|p| p.display().to_string());
 775
 776            let response = upstream_client
 777                .update(cx, |client, _| {
 778                    client
 779                        .proto_client()
 780                        .request(proto::GetContextServerCommand {
 781                            project_id,
 782                            server_id: id.0.to_string(),
 783                            root_dir: root_dir.clone(),
 784                        })
 785                })
 786                .await?;
 787
 788            let remote_command = upstream_client.update(cx, |client, _| {
 789                client.build_command(
 790                    Some(response.path),
 791                    &response.args,
 792                    &response.env.into_iter().collect(),
 793                    root_dir,
 794                    None,
 795                )
 796            })?;
 797
 798            let command = ContextServerCommand {
 799                path: remote_command.program.into(),
 800                args: remote_command.args,
 801                env: Some(remote_command.env.into_iter().collect()),
 802                timeout: None,
 803            };
 804
 805            Arc::new(ContextServerConfiguration::Custom { command, remote })
 806        } else {
 807            configuration
 808        };
 809
 810        if let Some(server) = this.update(cx, |this, _| {
 811            this.context_server_factory
 812                .as_ref()
 813                .map(|factory| factory(id.clone(), configuration.clone()))
 814        })? {
 815            return Ok((server, configuration));
 816        }
 817
 818        let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
 819            if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
 820                if configuration.has_static_auth_header() {
 821                    None
 822                } else {
 823                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
 824                    let http_client = cx.update(|cx| cx.http_client());
 825
 826                    match Self::load_session(&credentials_provider, url, &cx).await {
 827                        Ok(Some(session)) => {
 828                            log::info!("{} loaded cached OAuth session from keychain", id);
 829                            Some(Self::create_oauth_token_provider(
 830                                &id,
 831                                url,
 832                                session,
 833                                http_client,
 834                                credentials_provider,
 835                                cx,
 836                            ))
 837                        }
 838                        Ok(None) => None,
 839                        Err(err) => {
 840                            log::warn!("{} failed to load cached OAuth session: {}", id, err);
 841                            None
 842                        }
 843                    }
 844                }
 845            } else {
 846                None
 847            };
 848
 849        let server: Arc<ContextServer> = this.update(cx, |this, cx| {
 850            let global_timeout =
 851                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
 852
 853            match configuration.as_ref() {
 854                ContextServerConfiguration::Http {
 855                    url,
 856                    headers,
 857                    timeout,
 858                    oauth: _,
 859                } => {
 860                    let transport = HttpTransport::new_with_token_provider(
 861                        cx.http_client(),
 862                        url.to_string(),
 863                        headers.clone(),
 864                        cx.background_executor().clone(),
 865                        cached_token_provider.clone(),
 866                    );
 867                    anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
 868                        id,
 869                        Arc::new(transport),
 870                        Some(Duration::from_secs(
 871                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
 872                        )),
 873                    )))
 874                }
 875                _ => {
 876                    let mut command = configuration
 877                        .command()
 878                        .context("Missing command configuration for stdio context server")?
 879                        .clone();
 880                    command.timeout = Some(
 881                        command
 882                            .timeout
 883                            .unwrap_or(global_timeout)
 884                            .min(MAX_TIMEOUT_SECS),
 885                    );
 886
 887                    // Don't pass remote paths as working directory for locally-spawned processes
 888                    let working_directory = if is_remote_project { None } else { root_path };
 889                    anyhow::Ok(Arc::new(ContextServer::stdio(
 890                        id,
 891                        command,
 892                        working_directory,
 893                    )))
 894                }
 895            }
 896        })??;
 897
 898        Ok((server, configuration))
 899    }
 900
 901    async fn handle_get_context_server_command(
 902        this: Entity<Self>,
 903        envelope: TypedEnvelope<proto::GetContextServerCommand>,
 904        mut cx: AsyncApp,
 905    ) -> Result<proto::ContextServerCommand> {
 906        let server_id = ContextServerId(envelope.payload.server_id.into());
 907
 908        let (settings, registry, worktree_store) = this.update(&mut cx, |this, inner_cx| {
 909            let ContextServerStoreState::Local {
 910                is_headless: true, ..
 911            } = &this.state
 912            else {
 913                anyhow::bail!("unexpected GetContextServerCommand request in a non-local project");
 914            };
 915
 916            let settings = this
 917                .context_server_settings
 918                .get(&server_id.0)
 919                .cloned()
 920                .or_else(|| {
 921                    this.registry
 922                        .read(inner_cx)
 923                        .context_server_descriptor(&server_id.0)
 924                        .map(|_| ContextServerSettings::default_extension())
 925                })
 926                .with_context(|| format!("context server `{}` not found", server_id))?;
 927
 928            anyhow::Ok((settings, this.registry.clone(), this.worktree_store.clone()))
 929        })?;
 930
 931        let configuration = ContextServerConfiguration::from_settings(
 932            settings,
 933            server_id.clone(),
 934            registry,
 935            worktree_store,
 936            &cx,
 937        )
 938        .await
 939        .with_context(|| format!("failed to build configuration for `{}`", server_id))?;
 940
 941        let command = configuration
 942            .command()
 943            .context("context server has no command (HTTP servers don't need RPC)")?;
 944
 945        Ok(proto::ContextServerCommand {
 946            path: command.path.display().to_string(),
 947            args: command.args.clone(),
 948            env: command
 949                .env
 950                .clone()
 951                .map(|env| env.into_iter().collect())
 952                .unwrap_or_default(),
 953        })
 954    }
 955
 956    fn resolve_project_settings<'a>(
 957        worktree_store: &'a Entity<WorktreeStore>,
 958        cx: &'a App,
 959    ) -> &'a ProjectSettings {
 960        let location = worktree_store
 961            .read(cx)
 962            .visible_worktrees(cx)
 963            .next()
 964            .map(|worktree| settings::SettingsLocation {
 965                worktree_id: worktree.read(cx).id(),
 966                path: RelPath::empty(),
 967            });
 968        ProjectSettings::get(location, cx)
 969    }
 970
 971    fn create_oauth_token_provider(
 972        id: &ContextServerId,
 973        server_url: &url::Url,
 974        session: OAuthSession,
 975        http_client: Arc<dyn HttpClient>,
 976        credentials_provider: Arc<dyn CredentialsProvider>,
 977        cx: &mut AsyncApp,
 978    ) -> Arc<dyn oauth::OAuthTokenProvider> {
 979        let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
 980        let id = id.clone();
 981        let server_url = server_url.clone();
 982
 983        cx.spawn(async move |cx| {
 984            while let Some(refreshed_session) = token_refresh_rx.next().await {
 985                if let Err(err) =
 986                    Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
 987                        .await
 988                {
 989                    log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
 990                }
 991            }
 992            log::debug!("{} OAuth session persistence task ended", id);
 993        })
 994        .detach();
 995
 996        Arc::new(McpOAuthTokenProvider::new(
 997            session,
 998            http_client,
 999            Some(token_refresh_tx),
1000        ))
1001    }
1002
    /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
    ///
    /// This starts a loopback HTTP callback server on an ephemeral port, builds
    /// the authorization URL, opens the user's browser, waits for the callback,
    /// exchanges the code for tokens, persists them in the keychain, and restarts
    /// the server with the new token provider.
    ///
    /// If the server's settings declare a pre-registered OAuth client without a
    /// `client_secret`, the keychain is consulted first. When no secret is found
    /// there (or the keychain read fails), the server moves to
    /// `ClientSecretRequired` instead of starting the browser flow, so the UI can
    /// prompt for the secret (see `submit_client_secret`).
    ///
    /// The server is switched to `Authenticating` immediately. The spawned flow
    /// task is stored in that state. If the flow fails, the server moves to
    /// `Error` with the full (`{:#}`) error chain.
    ///
    /// # Errors
    /// Returns an error if `id` is not a known server or the server is not in
    /// the `AuthRequired` state.
    pub fn authenticate_server(
        &mut self,
        id: &ContextServerId,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let state = self.servers.get(id).context("Context server not found")?;

        let (discovery, server, configuration) = match state {
            ContextServerState::AuthRequired {
                discovery,
                server,
                configuration,
            } => (discovery.clone(), server.clone(), configuration.clone()),
            _ => anyhow::bail!("Server is not in AuthRequired state"),
        };

        // Only a pre-registered client whose secret is absent from settings
        // needs a keychain lookup; capture its URL for the check below.
        let needs_keychain_check = match configuration.as_ref() {
            ContextServerConfiguration::Http {
                url,
                oauth: Some(oauth_settings),
                ..
            } if oauth_settings.client_secret.is_none() => Some(url.clone()),
            _ => None,
        };

        let id = id.clone();

        let task = cx.spawn({
            let id = id.clone();
            let server = server.clone();
            let configuration = configuration.clone();
            async move |this, cx| {
                if let Some(server_url) = needs_keychain_check {
                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
                    // A failed keychain read is treated the same as "no secret stored".
                    let has_keychain_secret =
                        Self::load_client_secret(&credentials_provider, &server_url, cx)
                            .await
                            .ok()
                            .flatten()
                            .is_some();

                    if !has_keychain_secret {
                        // Ask the user for the secret instead of opening the browser.
                        this.update(cx, |this, cx| {
                            this.update_server_state(
                                id.clone(),
                                ContextServerState::ClientSecretRequired {
                                    server,
                                    configuration,
                                    discovery,
                                    error: None,
                                },
                                cx,
                            );
                        })
                        .log_err();
                        return;
                    }
                }

                let result = Self::run_oauth_flow(
                    this.clone(),
                    id.clone(),
                    discovery.clone(),
                    configuration.clone(),
                    cx,
                )
                .await;

                // On success, `run_oauth_flow` has already restarted the server.
                if let Err(err) = &result {
                    log::error!("{} OAuth authentication failed: {:?}", id, err);
                    this.update(cx, |this, cx| {
                        this.update_server_state(
                            id.clone(),
                            ContextServerState::Error {
                                server,
                                configuration,
                                error: format!("{err:#}").into(),
                            },
                            cx,
                        )
                    })
                    .log_err();
                }
            }
        });

        // Keep the flow task alive by storing it in the `Authenticating` state.
        self.update_server_state(
            id,
            ContextServerState::Authenticating {
                server,
                configuration,
                _task: task,
            },
            cx,
        );

        Ok(())
    }
1107
    /// Store the client secret and proceed with authentication.
    ///
    /// Valid only while the server is in the `ClientSecretRequired` state. A
    /// non-empty `secret` is written to the keychain (keyed by the server URL)
    /// before the OAuth flow runs. An empty `secret` skips storage (public
    /// client). The server is switched to `Authenticating` while the flow runs.
    ///
    /// On failure, an `unauthorized_client` token error clears the stored
    /// secret and returns the server to `ClientSecretRequired` with the error
    /// attached, so the user is prompted again. Any other failure moves the
    /// server to `Error`.
    ///
    /// # Errors
    /// Returns an error if `id` is unknown, the server is not in
    /// `ClientSecretRequired`, or its configuration is not HTTP.
    pub fn submit_client_secret(
        &mut self,
        id: &ContextServerId,
        secret: String,
        cx: &mut Context<Self>,
    ) -> Result<()> {
        let state = self.servers.get(id).context("Context server not found")?;

        let (server, configuration, discovery) = match state {
            ContextServerState::ClientSecretRequired {
                server,
                configuration,
                discovery,
                ..
            } => (server.clone(), configuration.clone(), discovery.clone()),
            _ => anyhow::bail!("Server is not in ClientSecretRequired state"),
        };

        let server_url = match configuration.as_ref() {
            ContextServerConfiguration::Http { url, .. } => url.clone(),
            _ => anyhow::bail!("OAuth only supported for HTTP servers"),
        };

        let id = id.clone();

        let task = cx.spawn({
            let id = id.clone();
            let server = server.clone();
            let configuration = configuration.clone();
            async move |this, cx| {
                // Store the secret if non-empty (empty means public client / skip).
                if !secret.is_empty() {
                    let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
                    // NOTE(review): a failed write is only logged. `run_oauth_flow`
                    // re-reads the secret from the keychain, so it would then
                    // proceed without one.
                    if let Err(err) =
                        Self::store_client_secret(&credentials_provider, &server_url, &secret, cx)
                            .await
                    {
                        log::error!(
                            "{} failed to store client secret in keychain: {:?}",
                            id,
                            err
                        );
                    }
                }

                let result = Self::run_oauth_flow(
                    this.clone(),
                    id.clone(),
                    discovery.clone(),
                    configuration.clone(),
                    cx,
                )
                .await;

                if let Err(err) = &result {
                    log::error!("{} OAuth authentication failed: {:?}", id, err);

                    // The token endpoint rejected the client credentials.
                    let is_bad_client_credentials = err
                        .downcast_ref::<oauth::OAuthTokenError>()
                        .is_some_and(|e| e.error == "unauthorized_client");

                    if is_bad_client_credentials {
                        // Clear the bad secret from the keychain so the user
                        // gets a fresh prompt.
                        let credentials_provider =
                            cx.update(|cx| zed_credentials_provider::global(cx));
                        Self::clear_client_secret(&credentials_provider, &server_url, cx)
                            .await
                            .log_err();

                        this.update(cx, |this, cx| {
                            this.update_server_state(
                                id.clone(),
                                ContextServerState::ClientSecretRequired {
                                    server,
                                    configuration,
                                    discovery,
                                    error: Some(format!("{err:#}").into()),
                                },
                                cx,
                            );
                        })
                        .log_err();
                    } else {
                        this.update(cx, |this, cx| {
                            this.update_server_state(
                                id.clone(),
                                ContextServerState::Error {
                                    server,
                                    configuration,
                                    error: format!("{err:#}").into(),
                                },
                                cx,
                            )
                        })
                        .log_err();
                    }
                }
            }
        });

        // Keep the flow task alive by storing it in the `Authenticating` state.
        self.update_server_state(
            id,
            ContextServerState::Authenticating {
                server,
                configuration,
                _task: task,
            },
            cx,
        );

        Ok(())
    }
1222
1223    async fn run_oauth_flow(
1224        this: WeakEntity<Self>,
1225        id: ContextServerId,
1226        discovery: Arc<OAuthDiscovery>,
1227        configuration: Arc<ContextServerConfiguration>,
1228        cx: &mut AsyncApp,
1229    ) -> Result<()> {
1230        let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
1231        let pkce = oauth::generate_pkce_challenge();
1232
1233        let mut state_bytes = [0u8; 32];
1234        rand::rng().fill(&mut state_bytes);
1235        let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
1236
1237        // Start a loopback HTTP server on an ephemeral port. The redirect URI
1238        // includes this port so the browser sends the callback directly to our
1239        // process.
1240        let (redirect_uri, callback_rx) = oauth::start_callback_server()
1241            .await
1242            .context("Failed to start OAuth callback server")?;
1243
1244        let http_client = cx.update(|cx| cx.http_client());
1245        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1246        let server_url = match configuration.as_ref() {
1247            ContextServerConfiguration::Http { url, .. } => url.clone(),
1248            _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1249        };
1250
1251        let client_registration = match configuration.as_ref() {
1252            ContextServerConfiguration::Http {
1253                url,
1254                oauth: Some(oauth_settings),
1255                ..
1256            } => {
1257                // Pre-registered client. Resolve the secret from settings, then keychain.
1258                let client_secret = if oauth_settings.client_secret.is_some() {
1259                    oauth_settings.client_secret.clone()
1260                } else {
1261                    Self::load_client_secret(&credentials_provider, url, cx)
1262                        .await
1263                        .ok()
1264                        .flatten()
1265                };
1266                oauth::OAuthClientRegistration {
1267                    client_id: oauth_settings.client_id.clone(),
1268                    client_secret,
1269                }
1270            }
1271            _ => oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
1272                .await
1273                .context("Failed to resolve OAuth client registration")?,
1274        };
1275
1276        let auth_url = oauth::build_authorization_url(
1277            &discovery.auth_server_metadata,
1278            &client_registration.client_id,
1279            &redirect_uri,
1280            &discovery.scopes,
1281            &resource,
1282            &pkce,
1283            &state_param,
1284        );
1285
1286        cx.update(|cx| cx.open_url(auth_url.as_str()));
1287
1288        let callback = callback_rx
1289            .await
1290            .map_err(|_| {
1291                anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
1292            })?
1293            .context("OAuth callback server received an invalid request")?;
1294
1295        if callback.state != state_param {
1296            anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
1297        }
1298
1299        let tokens = oauth::exchange_code(
1300            &http_client,
1301            &discovery.auth_server_metadata,
1302            &callback.code,
1303            &client_registration.client_id,
1304            &redirect_uri,
1305            &pkce.verifier,
1306            &resource,
1307            client_registration.client_secret.as_deref(),
1308        )
1309        .await
1310        .context("Failed to exchange authorization code for tokens")?;
1311
1312        let session = OAuthSession {
1313            token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
1314            resource: discovery.resource_metadata.resource.clone(),
1315            client_registration,
1316            tokens,
1317        };
1318
1319        Self::store_session(&credentials_provider, &server_url, &session, cx)
1320            .await
1321            .context("Failed to persist OAuth session in keychain")?;
1322
1323        let token_provider = Self::create_oauth_token_provider(
1324            &id,
1325            &server_url,
1326            session,
1327            http_client.clone(),
1328            credentials_provider,
1329            cx,
1330        );
1331
1332        let new_server = this.update(cx, |this, cx| {
1333            let global_timeout =
1334                Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
1335
1336            match configuration.as_ref() {
1337                ContextServerConfiguration::Http {
1338                    url,
1339                    headers,
1340                    timeout,
1341                    oauth: _,
1342                } => {
1343                    let transport = HttpTransport::new_with_token_provider(
1344                        http_client.clone(),
1345                        url.to_string(),
1346                        headers.clone(),
1347                        cx.background_executor().clone(),
1348                        Some(token_provider.clone()),
1349                    );
1350                    Ok(Arc::new(ContextServer::new_with_timeout(
1351                        id.clone(),
1352                        Arc::new(transport),
1353                        Some(Duration::from_secs(
1354                            timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
1355                        )),
1356                    )))
1357                }
1358                _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
1359            }
1360        })??;
1361
1362        this.update(cx, |this, cx| {
1363            this.run_server(new_server, configuration, cx);
1364        })?;
1365
1366        Ok(())
1367    }
1368
1369    /// Store the full OAuth session in the system keychain, keyed by the
1370    /// server's canonical URI.
1371    async fn store_session(
1372        credentials_provider: &Arc<dyn CredentialsProvider>,
1373        server_url: &url::Url,
1374        session: &OAuthSession,
1375        cx: &AsyncApp,
1376    ) -> Result<()> {
1377        let key = Self::keychain_key(server_url);
1378        let json = serde_json::to_string(session)?;
1379        credentials_provider
1380            .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
1381            .await
1382    }
1383
1384    /// Load the full OAuth session from the system keychain for the given
1385    /// server URL.
1386    async fn load_session(
1387        credentials_provider: &Arc<dyn CredentialsProvider>,
1388        server_url: &url::Url,
1389        cx: &AsyncApp,
1390    ) -> Result<Option<OAuthSession>> {
1391        let key = Self::keychain_key(server_url);
1392        match credentials_provider.read_credentials(&key, cx).await? {
1393            Some((_username, password_bytes)) => {
1394                let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
1395                Ok(Some(session))
1396            }
1397            None => Ok(None),
1398        }
1399    }
1400
1401    /// Clear the stored OAuth session from the system keychain.
1402    async fn clear_session(
1403        credentials_provider: &Arc<dyn CredentialsProvider>,
1404        server_url: &url::Url,
1405        cx: &AsyncApp,
1406    ) -> Result<()> {
1407        let key = Self::keychain_key(server_url);
1408        credentials_provider.delete_credentials(&key, cx).await
1409    }
1410
1411    fn keychain_key(server_url: &url::Url) -> String {
1412        format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
1413    }
1414
1415    fn client_secret_keychain_key(server_url: &url::Url) -> String {
1416        format!(
1417            "mcp-oauth-client-secret:{}",
1418            oauth::canonical_server_uri(server_url)
1419        )
1420    }
1421
1422    async fn load_client_secret(
1423        credentials_provider: &Arc<dyn CredentialsProvider>,
1424        server_url: &url::Url,
1425        cx: &AsyncApp,
1426    ) -> Result<Option<String>> {
1427        let key = Self::client_secret_keychain_key(server_url);
1428        match credentials_provider.read_credentials(&key, cx).await? {
1429            Some((_username, secret_bytes)) => Ok(Some(String::from_utf8(secret_bytes)?)),
1430            None => Ok(None),
1431        }
1432    }
1433
1434    pub async fn store_client_secret(
1435        credentials_provider: &Arc<dyn CredentialsProvider>,
1436        server_url: &url::Url,
1437        secret: &str,
1438        cx: &AsyncApp,
1439    ) -> Result<()> {
1440        let key = Self::client_secret_keychain_key(server_url);
1441        credentials_provider
1442            .write_credentials(&key, "mcp-oauth-client-secret", secret.as_bytes(), cx)
1443            .await
1444    }
1445
1446    async fn clear_client_secret(
1447        credentials_provider: &Arc<dyn CredentialsProvider>,
1448        server_url: &url::Url,
1449        cx: &AsyncApp,
1450    ) -> Result<()> {
1451        let key = Self::client_secret_keychain_key(server_url);
1452        credentials_provider.delete_credentials(&key, cx).await
1453    }
1454
1455    /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
1456    /// session from the keychain and stop the server.
1457    pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
1458        let state = self.servers.get(id).context("Context server not found")?;
1459        let configuration = state.configuration();
1460
1461        let server_url = match configuration.as_ref() {
1462            ContextServerConfiguration::Http { url, .. } => url.clone(),
1463            _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
1464        };
1465
1466        let id = id.clone();
1467        self.stop_server(&id, cx)?;
1468
1469        cx.spawn(async move |this, cx| {
1470            let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1471            if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
1472                log::error!("{} failed to clear OAuth session: {}", id, err);
1473            }
1474            // Also clear any client secret so the user gets a fresh prompt on
1475            // the next authentication attempt.
1476            Self::clear_client_secret(&credentials_provider, &server_url, &cx)
1477                .await
1478                .log_err();
1479            // Trigger server recreation so the next start uses a fresh
1480            // transport without the old (now-invalidated) token provider.
1481            this.update(cx, |this, cx| {
1482                this.available_context_servers_changed(cx);
1483            })
1484            .log_err();
1485        })
1486        .detach();
1487
1488        Ok(())
1489    }
1490
1491    fn update_server_state(
1492        &mut self,
1493        id: ContextServerId,
1494        state: ContextServerState,
1495        cx: &mut Context<Self>,
1496    ) {
1497        let status = ContextServerStatus::from_state(&state);
1498        self.servers.insert(id.clone(), state);
1499        cx.emit(ServerStatusChangedEvent {
1500            server_id: id,
1501            status,
1502        });
1503    }
1504
    /// Schedule a reconciliation of servers with the current configuration
    /// (`maintain_servers`).
    ///
    /// At most one reconciliation task exists at a time. If one is already in
    /// flight, this only sets `needs_server_update`. When the running task
    /// finishes, it calls this method again, so any number of changes that
    /// arrive during a run are coalesced into a single follow-up run.
    fn available_context_servers_changed(&mut self, cx: &mut Context<Self>) {
        if self.update_servers_task.is_some() {
            // A run is in progress; ask it to schedule another one when done.
            self.needs_server_update = true;
        } else {
            self.needs_server_update = false;
            self.update_servers_task = Some(cx.spawn(async move |this, cx| {
                if let Err(err) = Self::maintain_servers(this.clone(), cx).await {
                    log::error!("Error maintaining context servers: {}", err);
                }

                this.update(cx, |this, cx| {
                    this.populate_server_ids(cx);
                    cx.notify();
                    // Clear the in-flight marker before possibly re-scheduling,
                    // otherwise the re-entrant call would only set the flag again.
                    this.update_servers_task.take();
                    if this.needs_server_update {
                        this.available_context_servers_changed(cx);
                    }
                })?;

                Ok(())
            }));
        }
    }
1528
1529    async fn maintain_servers(this: WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
1530        // Don't start context servers if AI is disabled
1531        let ai_disabled = this.update(cx, |_, cx| DisableAiSettings::get_global(cx).disable_ai)?;
1532        if ai_disabled {
1533            // Stop all running servers when AI is disabled
1534            this.update(cx, |this, cx| {
1535                let server_ids: Vec<_> = this.servers.keys().cloned().collect();
1536                for id in server_ids {
1537                    let _ = this.stop_server(&id, cx);
1538                }
1539            })?;
1540            return Ok(());
1541        }
1542
1543        let (mut configured_servers, registry, worktree_store) = this.update(cx, |this, _| {
1544            (
1545                this.context_server_settings.clone(),
1546                this.registry.clone(),
1547                this.worktree_store.clone(),
1548            )
1549        })?;
1550
1551        for (id, _) in registry.read_with(cx, |registry, _| registry.context_server_descriptors()) {
1552            configured_servers
1553                .entry(id)
1554                .or_insert(ContextServerSettings::default_extension());
1555        }
1556
1557        let (enabled_servers, disabled_servers): (HashMap<_, _>, HashMap<_, _>) =
1558            configured_servers
1559                .into_iter()
1560                .partition(|(_, settings)| settings.enabled());
1561
1562        let configured_servers = join_all(enabled_servers.into_iter().map(|(id, settings)| {
1563            let id = ContextServerId(id);
1564            ContextServerConfiguration::from_settings(
1565                settings,
1566                id.clone(),
1567                registry.clone(),
1568                worktree_store.clone(),
1569                cx,
1570            )
1571            .map(move |config| (id, config))
1572        }))
1573        .await
1574        .into_iter()
1575        .filter_map(|(id, config)| config.map(|config| (id, config)))
1576        .collect::<HashMap<_, _>>();
1577
1578        let mut servers_to_start = Vec::new();
1579        let mut servers_to_remove = HashSet::default();
1580        let mut servers_to_stop = HashSet::default();
1581
1582        this.update(cx, |this, _cx| {
1583            for server_id in this.servers.keys() {
1584                // All servers that are not in desired_servers should be removed from the store.
1585                // This can happen if the user removed a server from the context server settings.
1586                if !configured_servers.contains_key(server_id) {
1587                    if disabled_servers.contains_key(&server_id.0) {
1588                        servers_to_stop.insert(server_id.clone());
1589                    } else {
1590                        servers_to_remove.insert(server_id.clone());
1591                    }
1592                }
1593            }
1594
1595            for (id, config) in configured_servers {
1596                let state = this.servers.get(&id);
1597                let is_stopped = matches!(state, Some(ContextServerState::Stopped { .. }));
1598                let existing_config = state.as_ref().map(|state| state.configuration());
1599                if existing_config.as_deref() != Some(&config) || is_stopped {
1600                    let config = Arc::new(config);
1601                    servers_to_start.push((id.clone(), config));
1602                    if this.servers.contains_key(&id) {
1603                        servers_to_stop.insert(id);
1604                    }
1605                }
1606            }
1607
1608            anyhow::Ok(())
1609        })??;
1610
1611        this.update(cx, |this, inner_cx| {
1612            for id in servers_to_stop {
1613                this.stop_server(&id, inner_cx)?;
1614            }
1615            for id in servers_to_remove {
1616                this.remove_server(&id, inner_cx)?;
1617            }
1618            anyhow::Ok(())
1619        })??;
1620
1621        for (id, config) in servers_to_start {
1622            match Self::create_context_server(this.clone(), id.clone(), config, cx).await {
1623                Ok((server, config)) => {
1624                    this.update(cx, |this, cx| {
1625                        this.run_server(server, config, cx);
1626                    })?;
1627                }
1628                Err(err) => {
1629                    log::error!("{id} context server failed to create: {err:#}");
1630                    this.update(cx, |_this, cx| {
1631                        cx.emit(ServerStatusChangedEvent {
1632                            server_id: id,
1633                            status: ContextServerStatus::Error(err.to_string().into()),
1634                        });
1635                        cx.notify();
1636                    })?;
1637                }
1638            }
1639        }
1640
1641        Ok(())
1642    }
1643}
1644
1645/// Determines the appropriate server state after a start attempt fails.
1646///
1647/// When the error is an HTTP 401 with no static auth header configured,
1648/// attempts OAuth discovery so the UI can offer an authentication flow.
1649async fn resolve_start_failure(
1650    id: &ContextServerId,
1651    err: anyhow::Error,
1652    server: Arc<ContextServer>,
1653    configuration: Arc<ContextServerConfiguration>,
1654    cx: &AsyncApp,
1655) -> ContextServerState {
1656    let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
1657        TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
1658    });
1659
1660    if www_authenticate.is_some() && configuration.has_static_auth_header() {
1661        log::warn!("{id} received 401 with a static Authorization header configured");
1662        return ContextServerState::Error {
1663            configuration,
1664            server,
1665            error: "Server returned 401 Unauthorized. Check your configured Authorization header."
1666                .into(),
1667        };
1668    }
1669
1670    let server_url = match configuration.as_ref() {
1671        ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
1672            url.clone()
1673        }
1674        _ => {
1675            if www_authenticate.is_some() {
1676                log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
1677            } else {
1678                log::error!("{id} context server failed to start: {err}");
1679            }
1680            return ContextServerState::Error {
1681                configuration,
1682                server,
1683                error: err.to_string().into(),
1684            };
1685        }
1686    };
1687
1688    // When the error is NOT a 401 but there is a cached OAuth session in the
1689    // keychain, the session is likely stale/expired and caused the failure
1690    // (e.g. timeout because the server rejected the token silently). Clear it
1691    // so the next start attempt can get a clean 401 and trigger the auth flow.
1692    if www_authenticate.is_none() {
1693        let credentials_provider = cx.update(|cx| zed_credentials_provider::global(cx));
1694        match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
1695            Ok(Some(_)) => {
1696                log::info!("{id} start failed with a cached OAuth session present; clearing it");
1697                ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
1698                    .await
1699                    .log_err();
1700            }
1701            _ => {
1702                log::error!("{id} context server failed to start: {err}");
1703                return ContextServerState::Error {
1704                    configuration,
1705                    server,
1706                    error: err.to_string().into(),
1707                };
1708            }
1709        }
1710    }
1711
1712    let default_www_authenticate = oauth::WwwAuthenticate {
1713        resource_metadata: None,
1714        scope: None,
1715        error: None,
1716        error_description: None,
1717    };
1718    let www_authenticate = www_authenticate
1719        .as_ref()
1720        .unwrap_or(&default_www_authenticate);
1721    let http_client = cx.update(|cx| cx.http_client());
1722
1723    match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
1724        Ok(discovery) => {
1725            use context_server::oauth::{
1726                ClientRegistrationStrategy, determine_registration_strategy,
1727            };
1728
1729            let has_preregistered_client_id = matches!(
1730                configuration.as_ref(),
1731                ContextServerConfiguration::Http { oauth: Some(_), .. }
1732            );
1733
1734            let strategy = determine_registration_strategy(&discovery.auth_server_metadata);
1735
1736            if matches!(strategy, ClientRegistrationStrategy::Unavailable)
1737                && !has_preregistered_client_id
1738            {
1739                log::error!(
1740                    "{id} authorization server supports neither CIMD nor DCR, \
1741                     and no pre-registered client_id is configured"
1742                );
1743                return ContextServerState::Error {
1744                    configuration,
1745                    server,
1746                    error: "Authorization server supports neither CIMD nor DCR. \
1747                            Configure a pre-registered client_id in your settings \
1748                            under the \"oauth\" key."
1749                        .into(),
1750                };
1751            }
1752
1753            log::info!(
1754                "{id} requires OAuth authorization (auth server: {})",
1755                discovery.auth_server_metadata.issuer,
1756            );
1757            ContextServerState::AuthRequired {
1758                server,
1759                configuration,
1760                discovery: Arc::new(discovery),
1761            }
1762        }
1763        Err(discovery_err) => {
1764            log::error!("{id} OAuth discovery failed: {discovery_err}");
1765            ContextServerState::Error {
1766                configuration,
1767                server,
1768                error: format!("OAuth discovery failed: {discovery_err}").into(),
1769            }
1770        }
1771    }
1772}