Detailed changes
@@ -3572,6 +3572,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
+ "base64 0.22.1",
"collections",
"futures 0.3.31",
"gpui",
@@ -3580,14 +3581,17 @@ dependencies = [
"net",
"parking_lot",
"postage",
+ "rand 0.9.2",
"schemars",
"serde",
"serde_json",
"settings",
+ "sha2",
"slotmap",
"smol",
"tempfile",
"terminal",
+ "tiny_http",
"url",
"util",
]
@@ -13189,6 +13193,7 @@ dependencies = [
"clock",
"collections",
"context_server",
+ "credentials_provider",
"dap",
"encoding_rs",
"extension",
@@ -253,12 +253,14 @@ impl ContextServerRegistry {
let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
match status {
- ContextServerStatus::Starting => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
ContextServerStatus::Running => {
self.reload_tools_for_server(server_id.clone(), cx);
self.reload_prompts_for_server(server_id.clone(), cx);
}
- ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+ ContextServerStatus::Stopped
+ | ContextServerStatus::Error(_)
+ | ContextServerStatus::AuthRequired => {
if let Some(registered_server) = self.registered_servers.remove(server_id) {
if !registered_server.tools.is_empty() {
cx.emit(ContextServerRegistryEvent::ToolsChanged);
@@ -517,11 +517,7 @@ impl AgentConfiguration {
}
}
- fn render_context_servers_section(
- &mut self,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) -> impl IntoElement {
+ fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
let context_server_ids = self.context_server_store.read(cx).server_ids();
let add_server_popover = PopoverMenu::new("add-server-popover")
@@ -601,7 +597,7 @@ impl AgentConfiguration {
} else {
parent.children(itertools::intersperse_with(
context_server_ids.iter().cloned().map(|context_server_id| {
- self.render_context_server(context_server_id, window, cx)
+ self.render_context_server(context_server_id, cx)
.into_any_element()
}),
|| {
@@ -618,7 +614,6 @@ impl AgentConfiguration {
fn render_context_server(
&self,
context_server_id: ContextServerId,
- window: &mut Window,
cx: &Context<Self>,
) -> impl use<> + IntoElement {
let server_status = self
@@ -646,6 +641,9 @@ impl AgentConfiguration {
} else {
None
};
+ let auth_required = matches!(server_status, ContextServerStatus::AuthRequired);
+ let authenticating = matches!(server_status, ContextServerStatus::Authenticating);
+ let context_server_store = self.context_server_store.clone();
let tool_count = self
.context_server_registry
@@ -689,11 +687,33 @@ impl AgentConfiguration {
Indicator::dot().color(Color::Muted).into_any_element(),
"Server is stopped.",
),
+ ContextServerStatus::AuthRequired => (
+ Indicator::dot().color(Color::Warning).into_any_element(),
+ "Authentication required.",
+ ),
+ ContextServerStatus::Authenticating => (
+ Icon::new(IconName::LoadCircle)
+ .size(IconSize::XSmall)
+ .color(Color::Accent)
+ .with_keyed_rotate_animation(
+ SharedString::from(format!("{}-authenticating", context_server_id.0)),
+ 3,
+ )
+ .into_any_element(),
+ "Waiting for authorization...",
+ ),
};
+
let is_remote = server_configuration
.as_ref()
.map(|config| matches!(config.as_ref(), ContextServerConfiguration::Http { .. }))
.unwrap_or(false);
+
+ let should_show_logout_button = server_configuration.as_ref().is_some_and(|config| {
+ matches!(config.as_ref(), ContextServerConfiguration::Http { .. })
+ && !config.has_static_auth_header()
+ });
+
let context_server_configuration_menu = PopoverMenu::new("context-server-config-menu")
.trigger_with_tooltip(
IconButton::new("context-server-config-menu", IconName::Settings)
@@ -708,6 +728,7 @@ impl AgentConfiguration {
let language_registry = self.language_registry.clone();
let workspace = self.workspace.clone();
let context_server_registry = self.context_server_registry.clone();
+ let context_server_store = context_server_store.clone();
move |window, cx| {
Some(ContextMenu::build(window, cx, |menu, _window, _cx| {
@@ -754,6 +775,17 @@ impl AgentConfiguration {
.ok();
}
}))
+ .when(should_show_logout_button, |this| {
+ this.entry("Log Out", None, {
+ let context_server_store = context_server_store.clone();
+ let context_server_id = context_server_id.clone();
+ move |_window, cx| {
+ context_server_store.update(cx, |store, cx| {
+ store.logout_server(&context_server_id, cx).log_err();
+ });
+ }
+ })
+ })
.separator()
.entry("Uninstall", None, {
let fs = fs.clone();
@@ -810,6 +842,9 @@ impl AgentConfiguration {
}
});
+ let feedback_base_container =
+ || h_flex().py_1().min_w_0().w_full().gap_1().justify_between();
+
v_flex()
.min_w_0()
.id(item_id.clone())
@@ -868,6 +903,7 @@ impl AgentConfiguration {
.on_click({
let context_server_manager = self.context_server_store.clone();
let fs = self.fs.clone();
+ let context_server_id = context_server_id.clone();
move |state, _window, cx| {
let is_enabled = match state {
@@ -915,30 +951,111 @@ impl AgentConfiguration {
)
.map(|parent| {
if let Some(error) = error {
+ return parent
+ .child(
+ feedback_base_container()
+ .child(
+ h_flex()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .child(
+ Icon::new(IconName::XCircle)
+ .size(IconSize::XSmall)
+ .color(Color::Error),
+ )
+ .child(
+ div().min_w_0().flex_1().child(
+ Label::new(error)
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ ),
+ ),
+ )
+ .when(should_show_logout_button, |this| {
+ this.child(
+ Button::new("error-logout-server", "Log Out")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let context_server_store =
+ context_server_store.clone();
+ let context_server_id =
+ context_server_id.clone();
+ move |_event, _window, cx| {
+ context_server_store.update(
+ cx,
+ |store, cx| {
+ store
+ .logout_server(
+ &context_server_id,
+ cx,
+ )
+ .log_err();
+ },
+ );
+ }
+ }),
+ )
+ }),
+ );
+ }
+ if auth_required {
return parent.child(
- h_flex()
- .gap_2()
- .pr_4()
- .items_start()
+ feedback_base_container()
.child(
h_flex()
- .flex_none()
- .h(window.line_height() / 1.6_f32)
- .justify_center()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
.child(
- Icon::new(IconName::XCircle)
+ Icon::new(IconName::Info)
.size(IconSize::XSmall)
- .color(Color::Error),
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("Authenticate to connect this server")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
),
)
.child(
- div().w_full().child(
- Label::new(error)
- .buffer_font(cx)
- .color(Color::Muted)
- .size(LabelSize::Small),
- ),
+ Button::new("error-logout-server", "Authenticate")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let context_server_store = context_server_store.clone();
+ let context_server_id = context_server_id.clone();
+ move |_event, _window, cx| {
+ context_server_store.update(cx, |store, cx| {
+ store
+ .authenticate_server(&context_server_id, cx)
+ .log_err();
+ });
+ }
+ }),
+ ),
+ );
+ }
+ if authenticating {
+ return parent.child(
+ h_flex()
+ .mt_1()
+ .pr_4()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .child(
+ div().size_3().flex_shrink_0(), // Alignment Div
+ )
+ .child(
+ Label::new("Authenticatingโฆ")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
),
+
);
}
parent
@@ -1234,7 +1351,7 @@ impl Render for AgentConfiguration {
.min_w_0()
.overflow_y_scroll()
.child(self.render_agent_servers_section(cx))
- .child(self.render_context_servers_section(window, cx))
+ .child(self.render_context_servers_section(cx))
.child(self.render_provider_configuration_section(cx)),
)
.vertical_scrollbar_for(&self.scroll_handle, window, cx),
@@ -1,25 +1,27 @@
-use std::sync::{Arc, Mutex};
-
use anyhow::{Context as _, Result};
use collections::HashMap;
use context_server::{ContextServerCommand, ContextServerId};
use editor::{Editor, EditorElement, EditorStyle};
+
use gpui::{
AsyncWindowContext, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, ScrollHandle,
- Task, TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, prelude::*,
+ Subscription, Task, TextStyle, TextStyleRefinement, UnderlineStyle, WeakEntity, prelude::*,
};
use language::{Language, LanguageRegistry};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
use notifications::status_toast::{StatusToast, ToastIcon};
+use parking_lot::Mutex;
use project::{
context_server_store::{
- ContextServerStatus, ContextServerStore, registry::ContextServerDescriptorRegistry,
+ ContextServerStatus, ContextServerStore, ServerStatusChangedEvent,
+ registry::ContextServerDescriptorRegistry,
},
project_settings::{ContextServerSettings, ProjectSettings},
worktree_store::WorktreeStore,
};
use serde::Deserialize;
use settings::{Settings as _, update_settings_file};
+use std::sync::Arc;
use theme::ThemeSettings;
use ui::{
CommonAnimationExt, KeyBinding, Modal, ModalFooter, ModalHeader, Section, Tooltip,
@@ -237,6 +239,8 @@ fn context_server_input(existing: Option<(ContextServerId, ContextServerCommand)
format!(
r#"{{
+ /// Configure an MCP server that runs locally via stdin/stdout
+ ///
/// The name of your MCP server
"{name}": {{
/// The command which runs the MCP server
@@ -280,6 +284,8 @@ fn context_server_http_input(
format!(
r#"{{
+ /// Configure an MCP server that you connect to over HTTP
+ ///
/// The name of your remote MCP server
"{name}": {{
/// The URL of the remote MCP server
@@ -342,6 +348,8 @@ fn resolve_context_server_extension(
enum State {
Idle,
Waiting,
+ AuthRequired { server_id: ContextServerId },
+ Authenticating { _server_id: ContextServerId },
Error(SharedString),
}
@@ -352,6 +360,7 @@ pub struct ConfigureContextServerModal {
state: State,
original_server_id: Option<ContextServerId>,
scroll_handle: ScrollHandle,
+ _auth_subscription: Option<Subscription>,
}
impl ConfigureContextServerModal {
@@ -475,6 +484,7 @@ impl ConfigureContextServerModal {
cx,
),
scroll_handle: ScrollHandle::new(),
+ _auth_subscription: None,
})
})
})
@@ -486,6 +496,13 @@ impl ConfigureContextServerModal {
}
fn confirm(&mut self, _: &menu::Confirm, cx: &mut Context<Self>) {
+ if matches!(
+ self.state,
+ State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
+ ) {
+ return;
+ }
+
self.state = State::Idle;
let Some(workspace) = self.workspace.upgrade() else {
return;
@@ -515,14 +532,19 @@ impl ConfigureContextServerModal {
async move |this, cx| {
let result = wait_for_context_server_task.await;
this.update(cx, |this, cx| match result {
- Ok(_) => {
+ Ok(ContextServerStatus::Running) => {
this.state = State::Idle;
this.show_configured_context_server_toast(id, cx);
cx.emit(DismissEvent);
}
+ Ok(ContextServerStatus::AuthRequired) => {
+ this.state = State::AuthRequired { server_id: id };
+ cx.notify();
+ }
Err(err) => {
this.set_error(err, cx);
}
+ Ok(_) => {}
})
}
})
@@ -558,6 +580,49 @@ impl ConfigureContextServerModal {
cx.emit(DismissEvent);
}
+ fn authenticate(&mut self, server_id: ContextServerId, cx: &mut Context<Self>) {
+ self.context_server_store.update(cx, |store, cx| {
+ store.authenticate_server(&server_id, cx).log_err();
+ });
+
+ self.state = State::Authenticating {
+ _server_id: server_id.clone(),
+ };
+
+ self._auth_subscription = Some(cx.subscribe(
+ &self.context_server_store,
+ move |this, _, event: &ServerStatusChangedEvent, cx| {
+ if event.server_id != server_id {
+ return;
+ }
+ match &event.status {
+ ContextServerStatus::Running => {
+ this._auth_subscription = None;
+ this.state = State::Idle;
+ this.show_configured_context_server_toast(event.server_id.clone(), cx);
+ cx.emit(DismissEvent);
+ }
+ ContextServerStatus::AuthRequired => {
+ this._auth_subscription = None;
+ this.state = State::AuthRequired {
+ server_id: event.server_id.clone(),
+ };
+ cx.notify();
+ }
+ ContextServerStatus::Error(error) => {
+ this._auth_subscription = None;
+ this.set_error(error.clone(), cx);
+ }
+ ContextServerStatus::Authenticating
+ | ContextServerStatus::Starting
+ | ContextServerStatus::Stopped => {}
+ }
+ },
+ ));
+
+ cx.notify();
+ }
+
fn show_configured_context_server_toast(&self, id: ContextServerId, cx: &mut App) {
self.workspace
.update(cx, {
@@ -615,7 +680,8 @@ impl ConfigureContextServerModal {
}
fn render_modal_description(&self, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
- const MODAL_DESCRIPTION: &str = "Visit the MCP server configuration docs to find all necessary arguments and environment variables.";
+ const MODAL_DESCRIPTION: &str =
+ "Check the server docs for required arguments and environment variables.";
if let ConfigurationSource::Extension {
installation_instructions: Some(installation_instructions),
@@ -637,6 +703,67 @@ impl ConfigureContextServerModal {
}
}
+ fn render_tab_bar(&self, cx: &mut Context<Self>) -> Option<AnyElement> {
+ let is_http = match &self.source {
+ ConfigurationSource::New { is_http, .. } => *is_http,
+ _ => return None,
+ };
+
+ let tab = |label: &'static str, active: bool| {
+ div()
+ .id(label)
+ .cursor_pointer()
+ .p_1()
+ .text_sm()
+ .border_b_1()
+ .when(active, |this| {
+ this.border_color(cx.theme().colors().border_focused)
+ })
+ .when(!active, |this| {
+ this.border_color(gpui::transparent_black())
+ .text_color(cx.theme().colors().text_muted)
+ .hover(|s| s.text_color(cx.theme().colors().text))
+ })
+ .child(label)
+ };
+
+ Some(
+ h_flex()
+ .pt_1()
+ .mb_2p5()
+ .gap_1()
+ .border_b_1()
+ .border_color(cx.theme().colors().border.opacity(0.5))
+ .child(
+ tab("Local", !is_http).on_click(cx.listener(|this, _, window, cx| {
+ if let ConfigurationSource::New { editor, is_http } = &mut this.source {
+ if *is_http {
+ *is_http = false;
+ let new_text = context_server_input(None);
+ editor.update(cx, |editor, cx| {
+ editor.set_text(new_text, window, cx);
+ });
+ }
+ }
+ })),
+ )
+ .child(
+ tab("Remote", is_http).on_click(cx.listener(|this, _, window, cx| {
+ if let ConfigurationSource::New { editor, is_http } = &mut this.source {
+ if !*is_http {
+ *is_http = true;
+ let new_text = context_server_http_input(None);
+ editor.update(cx, |editor, cx| {
+ editor.set_text(new_text, window, cx);
+ });
+ }
+ }
+ })),
+ )
+ .into_any_element(),
+ )
+ }
+
fn render_modal_content(&self, cx: &App) -> AnyElement {
let editor = match &self.source {
ConfigurationSource::New { editor, .. } => editor,
@@ -682,7 +809,10 @@ impl ConfigureContextServerModal {
fn render_modal_footer(&self, cx: &mut Context<Self>) -> ModalFooter {
let focus_handle = self.focus_handle(cx);
- let is_connecting = matches!(self.state, State::Waiting);
+ let is_busy = matches!(
+ self.state,
+ State::Waiting | State::AuthRequired { .. } | State::Authenticating { .. }
+ );
ModalFooter::new()
.start_slot::<Button>(
@@ -714,36 +844,6 @@ impl ConfigureContextServerModal {
move |_, _, cx| cx.open_url(&repository_url)
}),
)
- } else if let ConfigurationSource::New { is_http, .. } = &self.source {
- let label = if *is_http {
- "Configure Local"
- } else {
- "Configure Remote"
- };
- let tooltip = if *is_http {
- "Configure an MCP server that runs on stdin/stdout."
- } else {
- "Configure an MCP server that you connect to over HTTP"
- };
-
- Some(
- Button::new("toggle-kind", label)
- .tooltip(Tooltip::text(tooltip))
- .on_click(cx.listener(|this, _, window, cx| match &mut this.source {
- ConfigurationSource::New { editor, is_http } => {
- *is_http = !*is_http;
- let new_text = if *is_http {
- context_server_http_input(None)
- } else {
- context_server_input(None)
- };
- editor.update(cx, |editor, cx| {
- editor.set_text(new_text, window, cx);
- })
- }
- _ => {}
- })),
- )
} else {
None
},
@@ -777,7 +877,7 @@ impl ConfigureContextServerModal {
"Configure Server"
},
)
- .disabled(is_connecting)
+ .disabled(is_busy)
.key_binding(
KeyBinding::for_action_in(&menu::Confirm, &focus_handle, cx)
.map(|kb| kb.size(rems_from_px(12.))),
@@ -791,29 +891,62 @@ impl ConfigureContextServerModal {
)
}
- fn render_waiting_for_context_server() -> Div {
+ fn render_loading(&self, label: impl Into<SharedString>) -> Div {
h_flex()
- .gap_2()
+ .h_8()
+ .gap_1p5()
+ .justify_center()
.child(
- Icon::new(IconName::ArrowCircle)
+ Icon::new(IconName::LoadCircle)
.size(IconSize::XSmall)
- .color(Color::Info)
- .with_rotate_animation(2)
- .into_any_element(),
+ .color(Color::Muted)
+ .with_rotate_animation(3),
)
+ .child(Label::new(label).size(LabelSize::Small).color(Color::Muted))
+ }
+
+ fn render_auth_required(&self, server_id: &ContextServerId, cx: &mut Context<Self>) -> Div {
+ h_flex()
+ .h_8()
+ .min_w_0()
+ .w_full()
+ .gap_2()
+ .justify_center()
.child(
- Label::new("Waiting for Context Server")
- .size(LabelSize::Small)
- .color(Color::Muted),
+ h_flex()
+ .gap_1p5()
+ .child(
+ Icon::new(IconName::Info)
+ .size(IconSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("Authenticate to connect this server")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ Button::new("authenticate-server", "Authenticate")
+ .style(ButtonStyle::Outlined)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let server_id = server_id.clone();
+ cx.listener(move |this, _event, _window, cx| {
+ this.authenticate(server_id.clone(), cx);
+ })
+ }),
)
}
fn render_modal_error(error: SharedString) -> Div {
h_flex()
- .gap_2()
+ .h_8()
+ .gap_1p5()
+ .justify_center()
.child(
Icon::new(IconName::Warning)
- .size(IconSize::XSmall)
+ .size(IconSize::Small)
.color(Color::Warning),
)
.child(
@@ -828,7 +961,7 @@ impl Render for ConfigureContextServerModal {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
div()
.elevation_3(cx)
- .w(rems(34.))
+ .w(rems(40.))
.key_context("ConfigureContextServerModal")
.on_action(
cx.listener(|this, _: &menu::Cancel, _window, cx| this.cancel(&menu::Cancel, cx)),
@@ -855,11 +988,18 @@ impl Render for ConfigureContextServerModal {
.overflow_y_scroll()
.track_scroll(&self.scroll_handle)
.child(self.render_modal_description(window, cx))
+ .children(self.render_tab_bar(cx))
.child(self.render_modal_content(cx))
.child(match &self.state {
State::Idle => div(),
State::Waiting => {
- Self::render_waiting_for_context_server()
+ self.render_loading("Connecting Serverโฆ")
+ }
+ State::AuthRequired { server_id } => {
+ self.render_auth_required(&server_id.clone(), cx)
+ }
+ State::Authenticating { .. } => {
+ self.render_loading("Authenticatingโฆ")
}
State::Error(error) => {
Self::render_modal_error(error.clone())
@@ -878,7 +1018,7 @@ fn wait_for_context_server(
context_server_store: &Entity<ContextServerStore>,
context_server_id: ContextServerId,
cx: &mut App,
-) -> Task<Result<(), Arc<str>>> {
+) -> Task<Result<ContextServerStatus, Arc<str>>> {
use std::time::Duration;
const WAIT_TIMEOUT: Duration = Duration::from_secs(120);
@@ -888,31 +1028,29 @@ fn wait_for_context_server(
let context_server_id_for_timeout = context_server_id.clone();
let subscription = cx.subscribe(context_server_store, move |_, event, _cx| {
- let project::context_server_store::ServerStatusChangedEvent { server_id, status } = event;
+ let ServerStatusChangedEvent { server_id, status } = event;
+
+ if server_id != &context_server_id {
+ return;
+ }
match status {
- ContextServerStatus::Running => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
- let _ = tx.send(Ok(()));
+ ContextServerStatus::Running | ContextServerStatus::AuthRequired => {
+ if let Some(tx) = tx.lock().take() {
+ let _ = tx.send(Ok(status.clone()));
}
}
ContextServerStatus::Stopped => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
+ if let Some(tx) = tx.lock().take() {
let _ = tx.send(Err("Context server stopped running".into()));
}
}
ContextServerStatus::Error(error) => {
- if server_id == &context_server_id
- && let Some(tx) = tx.lock().unwrap().take()
- {
+ if let Some(tx) = tx.lock().take() {
let _ = tx.send(Err(error.clone()));
}
}
- _ => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
}
});
@@ -901,14 +901,16 @@ impl TextThreadStore {
cx,
);
}
- ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
+ ContextServerStatus::Stopped
+ | ContextServerStatus::Error(_)
+ | ContextServerStatus::AuthRequired => {
if let Some(slash_command_ids) =
self.context_server_slash_command_ids.remove(server_id)
{
self.slash_commands.remove(&slash_command_ids);
}
}
- _ => {}
+ ContextServerStatus::Starting | ContextServerStatus::Authenticating => {}
}
}
@@ -17,6 +17,7 @@ test-support = ["gpui/test-support"]
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
+base64.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
@@ -24,14 +25,17 @@ http_client = { workspace = true, features = ["test-support"] }
log.workspace = true
net.workspace = true
parking_lot.workspace = true
+rand.workspace = true
postage.workspace = true
schemars.workspace = true
serde_json.workspace = true
serde.workspace = true
settings.workspace = true
+sha2.workspace = true
slotmap.workspace = true
smol.workspace = true
tempfile.workspace = true
+tiny_http.workspace = true
url = { workspace = true, features = ["serde"] }
util.workspace = true
terminal.workspace = true
@@ -35,7 +35,7 @@ pub const METHOD_NOT_FOUND: i32 = -32601;
pub const INVALID_PARAMS: i32 = -32602;
pub const INTERNAL_ERROR: i32 = -32603;
-type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
+type ResponseHandler = Box<dyn Send + FnOnce(String)>;
type NotificationHandler = Box<dyn Send + FnMut(Value, AsyncApp)>;
type RequestHandler = Box<dyn Send + FnMut(RequestId, &RawValue, AsyncApp)>;
@@ -62,6 +62,14 @@ pub(crate) struct Client {
#[allow(dead_code)]
transport: Arc<dyn Transport>,
request_timeout: Option<Duration>,
+ /// Single-slot side channel for the last transport-level error. When the
+ /// output task encounters a send failure it stashes the error here and
+ /// exits; the next request to observe cancellation `.take()`s it so it can
+ /// propagate a typed error (e.g. `TransportError::AuthRequired`) instead
+ /// of a generic "cancelled". This works because `initialize` is the sole
+ /// in-flight request at startup, but would need rethinking if concurrent
+ /// requests are ever issued during that phase.
+ last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -223,13 +231,16 @@ impl Client {
input.or(err)
});
+ let last_transport_error: Arc<Mutex<Option<anyhow::Error>>> = Arc::new(Mutex::new(None));
let output_task = cx.background_spawn({
let transport = transport.clone();
+ let last_transport_error = last_transport_error.clone();
Self::handle_output(
transport,
outbound_rx,
output_done_tx,
response_handlers.clone(),
+ last_transport_error,
)
.log_err()
});
@@ -246,6 +257,7 @@ impl Client {
output_done_rx: Mutex::new(Some(output_done_rx)),
transport,
request_timeout,
+ last_transport_error,
})
}
@@ -279,7 +291,7 @@ impl Client {
if let Some(handlers) = response_handlers.lock().as_mut()
&& let Some(handler) = handlers.remove(&response.id)
{
- handler(Ok(message.to_string()));
+ handler(message.to_string());
}
} else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&message) {
subscription_set.lock().notify(
@@ -315,6 +327,7 @@ impl Client {
outbound_rx: channel::Receiver<String>,
output_done_tx: barrier::Sender,
response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+ last_transport_error: Arc<Mutex<Option<anyhow::Error>>>,
) -> anyhow::Result<()> {
let _clear_response_handlers = util::defer({
let response_handlers = response_handlers.clone();
@@ -324,7 +337,11 @@ impl Client {
});
while let Ok(message) = outbound_rx.recv().await {
log::trace!("outgoing message: {}", message);
- transport.send(message).await?;
+ if let Err(err) = transport.send(message).await {
+ log::debug!("transport send failed: {:#}", err);
+ *last_transport_error.lock() = Some(err);
+ return Ok(());
+ }
}
drop(output_done_tx);
Ok(())
@@ -408,7 +425,7 @@ impl Client {
response = rx.fuse() => {
let elapsed = started.elapsed();
log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
- match response? {
+ match response {
Ok(response) => {
let parsed: AnyResponse = serde_json::from_str(&response)?;
if let Some(error) = parsed.error {
@@ -419,7 +436,12 @@ impl Client {
anyhow::bail!("Invalid response: no result or error");
}
}
- Err(_) => anyhow::bail!("cancelled")
+ Err(_canceled) => {
+ if let Some(err) = self.last_transport_error.lock().take() {
+ return Err(err);
+ }
+ anyhow::bail!("cancelled")
+ }
}
}
_ = cancel_fut => {
@@ -1,5 +1,6 @@
pub mod client;
pub mod listener;
+pub mod oauth;
pub mod protocol;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
@@ -0,0 +1,2800 @@
+//! OAuth 2.0 authentication for MCP servers using the Authorization Code +
+//! PKCE flow, per the MCP spec's OAuth profile.
+//!
+//! The flow is split into two phases:
+//!
+//! 1. **Discovery** ([`discover`]) fetches Protected Resource Metadata and
+//! Authorization Server Metadata. This can happen early (e.g. on a 401
+//! during server startup) because it doesn't need the redirect URI yet.
+//!
+//! 2. **Client registration** ([`resolve_client_registration`]) is separate
+//! because DCR requires the actual loopback redirect URI, which includes an
+//! ephemeral port that only exists once the callback server has started.
+//!
+//! After authentication, the full state is captured in [`OAuthSession`] which
+//! is persisted to the keychain. On next startup, the stored session feeds
+//! directly into [`McpOAuthTokenProvider`], giving a refresh-capable provider
+//! without requiring another browser flow.
+
+use anyhow::{Context as _, Result, anyhow, bail};
+use async_trait::async_trait;
+use base64::Engine as _;
+use futures::AsyncReadExt as _;
+use futures::channel::mpsc;
+use http_client::{AsyncBody, HttpClient, Request};
+use parking_lot::Mutex as SyncMutex;
+use rand::Rng as _;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::{Duration, SystemTime};
+use url::Url;
+use util::ResultExt as _;
+
+/// The CIMD URL where Zed's OAuth client metadata document is hosted.
+pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
+
+/// Validate that a URL is safe to use as an OAuth endpoint.
+///
+/// OAuth endpoints carry sensitive material (authorization codes, PKCE
+/// verifiers, tokens) and must use TLS. Plain HTTP is only permitted for
+/// loopback addresses, per RFC 8252 Section 8.3.
+fn require_https_or_loopback(url: &Url) -> Result<()> {
+ if url.scheme() == "https" {
+ return Ok(());
+ }
+ if url.scheme() == "http" {
+ if let Some(host) = url.host() {
+ match host {
+ url::Host::Ipv4(ip) if ip.is_loopback() => return Ok(()),
+ url::Host::Ipv6(ip) if ip.is_loopback() => return Ok(()),
+ url::Host::Domain(d) if d.eq_ignore_ascii_case("localhost") => return Ok(()),
+ _ => {}
+ }
+ }
+ }
+ bail!(
+ "OAuth endpoint must use HTTPS (got {}://{})",
+ url.scheme(),
+ url.host_str().unwrap_or("?")
+ )
+}
+
+/// Validate that a URL is safe to use as an OAuth endpoint, including SSRF
+/// protections against private/reserved IP ranges.
+///
+/// This wraps [`require_https_or_loopback`] and adds IP-range checks to prevent
+/// an attacker-controlled MCP server from directing Zed to fetch internal
+/// network resources via metadata URLs.
+///
+/// **Known limitation:** Domain-name URLs that resolve to private IPs are *not*
+/// blocked here โ full mitigation requires resolver-level validation (e.g. a
+/// custom `Resolve` implementation). This function only blocks IP-literal URLs.
+fn validate_oauth_url(url: &Url) -> Result<()> {
+ require_https_or_loopback(url)?;
+
+ if let Some(host) = url.host() {
+ match host {
+ url::Host::Ipv4(ip) => {
+ // Loopback is already allowed by require_https_or_loopback.
+ if ip.is_private() || ip.is_link_local() || ip.is_broadcast() || ip.is_unspecified()
+ {
+ bail!(
+ "OAuth endpoint must not point to private/reserved IP: {}",
+ ip
+ );
+ }
+ }
+ url::Host::Ipv6(ip) => {
+ // Check for IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) which
+ // could bypass the IPv4 checks above.
+ if let Some(mapped_v4) = ip.to_ipv4_mapped() {
+ if mapped_v4.is_private()
+ || mapped_v4.is_link_local()
+ || mapped_v4.is_broadcast()
+ || mapped_v4.is_unspecified()
+ {
+ bail!(
+ "OAuth endpoint must not point to private/reserved IP: ::ffff:{}",
+ mapped_v4
+ );
+ }
+ }
+
+ if ip.is_unspecified() || ip.is_multicast() {
+ bail!(
+ "OAuth endpoint must not point to reserved IPv6 address: {}",
+ ip
+ );
+ }
+ // IPv6 Unique Local Addresses (fc00::/7). is_unique_local() is
+ // nightly-only, so check the prefix manually.
+ if (ip.segments()[0] & 0xfe00) == 0xfc00 {
+ bail!(
+ "OAuth endpoint must not point to IPv6 unique-local address: {}",
+ ip
+ );
+ }
+ }
+ url::Host::Domain(_) => {
+ // Domain-based SSRF prevention requires resolver-level checks.
+ // See known limitation in the doc comment above.
+ }
+ }
+ }
+
+ Ok(())
+}
+
+/// Parsed from the MCP server's WWW-Authenticate header or well-known endpoint
+/// per RFC 9728 (OAuth 2.0 Protected Resource Metadata).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ProtectedResourceMetadata {
+ pub resource: Url,
+ pub authorization_servers: Vec<Url>,
+ pub scopes_supported: Option<Vec<String>>,
+}
+
+/// Parsed from the authorization server's .well-known endpoint
+/// per RFC 8414 (OAuth 2.0 Authorization Server Metadata).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AuthServerMetadata {
+ pub issuer: Url,
+ pub authorization_endpoint: Url,
+ pub token_endpoint: Url,
+ pub registration_endpoint: Option<Url>,
+ pub scopes_supported: Option<Vec<String>>,
+ pub code_challenge_methods_supported: Option<Vec<String>>,
+ pub client_id_metadata_document_supported: bool,
+}
+
+/// The result of client registration โ either CIMD or DCR.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct OAuthClientRegistration {
+ pub client_id: String,
+ /// Only present for DCR-minted registrations.
+ pub client_secret: Option<String>,
+}
+
+impl std::fmt::Debug for OAuthClientRegistration {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthClientRegistration")
+ .field("client_id", &self.client_id)
+ .field(
+ "client_secret",
+ &self.client_secret.as_ref().map(|_| "[redacted]"),
+ )
+ .finish()
+ }
+}
+
+/// Access and refresh tokens obtained from the token endpoint.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct OAuthTokens {
+ pub access_token: String,
+ pub refresh_token: Option<String>,
+ pub expires_at: Option<SystemTime>,
+}
+
+impl std::fmt::Debug for OAuthTokens {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthTokens")
+ .field("access_token", &"[redacted]")
+ .field(
+ "refresh_token",
+ &self.refresh_token.as_ref().map(|_| "[redacted]"),
+ )
+ .field("expires_at", &self.expires_at)
+ .finish()
+ }
+}
+
+/// Everything discovered before the browser flow starts. Client registration is
+/// resolved separately, once the real redirect URI is known.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OAuthDiscovery {
+ pub resource_metadata: ProtectedResourceMetadata,
+ pub auth_server_metadata: AuthServerMetadata,
+ pub scopes: Vec<String>,
+}
+
+/// The persisted OAuth session for a context server.
+///
+/// Stored in the keychain so startup can restore a refresh-capable provider
+/// without another browser flow. Deliberately excludes the full discovery
+/// metadata to keep the serialized size well within keychain item limits.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct OAuthSession {
+ pub token_endpoint: Url,
+ pub resource: Url,
+ pub client_registration: OAuthClientRegistration,
+ pub tokens: OAuthTokens,
+}
+
+impl std::fmt::Debug for OAuthSession {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthSession")
+ .field("token_endpoint", &self.token_endpoint)
+ .field("resource", &self.resource)
+ .field("client_registration", &self.client_registration)
+ .field("tokens", &self.tokens)
+ .finish()
+ }
+}
+
+/// Error codes defined by RFC 6750 Section 3.1 for Bearer token authentication.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BearerError {
+ /// The request is missing a required parameter, includes an unsupported
+ /// parameter or parameter value, or is otherwise malformed.
+ InvalidRequest,
+ /// The access token provided is expired, revoked, malformed, or invalid.
+ InvalidToken,
+ /// The request requires higher privileges than provided by the access token.
+ InsufficientScope,
+ /// An unrecognized error code (extension or future spec addition).
+ Other,
+}
+
+impl BearerError {
+ fn parse(value: &str) -> Self {
+ match value {
+ "invalid_request" => BearerError::InvalidRequest,
+ "invalid_token" => BearerError::InvalidToken,
+ "insufficient_scope" => BearerError::InsufficientScope,
+ _ => BearerError::Other,
+ }
+ }
+}
+
+/// Fields extracted from a `WWW-Authenticate: Bearer` header.
+///
+/// Per RFC 9728 Section 5.1, MCP servers include `resource_metadata` to point
+/// at the Protected Resource Metadata document. The optional `scope` parameter
+/// (RFC 6750 Section 3) indicates scopes required for the request.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct WwwAuthenticate {
+ pub resource_metadata: Option<Url>,
+ pub scope: Option<Vec<String>>,
+ /// The parsed `error` parameter per RFC 6750 Section 3.1.
+ pub error: Option<BearerError>,
+ pub error_description: Option<String>,
+}
+
+/// Parse a `WWW-Authenticate` header value.
+///
+/// Expects the `Bearer` scheme followed by comma-separated `key="value"` pairs.
+/// Per RFC 6750 and RFC 9728, the relevant parameters are:
+/// - `resource_metadata` โ URL of the Protected Resource Metadata document
+/// - `scope` โ space-separated list of required scopes
+/// - `error` โ error code (e.g. "insufficient_scope")
+/// - `error_description` โ human-readable error description
+pub fn parse_www_authenticate(header: &str) -> Result<WwwAuthenticate> {
+ let header = header.trim();
+
+ let params_str = if header.len() >= 6 && header[..6].eq_ignore_ascii_case("bearer") {
+ header[6..].trim()
+ } else {
+ bail!("WWW-Authenticate header does not use Bearer scheme");
+ };
+
+ if params_str.is_empty() {
+ return Ok(WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ });
+ }
+
+ let params = parse_auth_params(params_str);
+
+ let resource_metadata = params
+ .get("resource_metadata")
+ .map(|v| Url::parse(v))
+ .transpose()
+ .map_err(|e| anyhow!("invalid resource_metadata URL: {}", e))?;
+
+ let scope = params
+ .get("scope")
+ .map(|v| v.split_whitespace().map(String::from).collect());
+
+ let error = params.get("error").map(|v| BearerError::parse(v));
+ let error_description = params.get("error_description").cloned();
+
+ Ok(WwwAuthenticate {
+ resource_metadata,
+ scope,
+ error,
+ error_description,
+ })
+}
+
+/// Parse comma-separated `key="value"` or `key=token` parameters from an
+/// auth-param list (RFC 7235 Section 2.1).
+fn parse_auth_params(input: &str) -> collections::HashMap<String, String> {
+ let mut params = collections::HashMap::default();
+ let mut remaining = input.trim();
+
+ while !remaining.is_empty() {
+ // Skip leading whitespace and commas.
+ remaining = remaining.trim_start_matches(|c: char| c == ',' || c.is_whitespace());
+ if remaining.is_empty() {
+ break;
+ }
+
+ // Find the key (everything before '=').
+ let eq_pos = match remaining.find('=') {
+ Some(pos) => pos,
+ None => break,
+ };
+
+ let key = remaining[..eq_pos].trim().to_lowercase();
+ remaining = &remaining[eq_pos + 1..];
+ remaining = remaining.trim_start();
+
+ // Parse the value: either quoted or unquoted (token).
+ let value;
+ if remaining.starts_with('"') {
+ // Quoted string: find the closing quote, handling escaped chars.
+ remaining = &remaining[1..]; // skip opening quote
+ let mut val = String::new();
+ let mut chars = remaining.char_indices();
+ loop {
+ match chars.next() {
+ Some((_, '\\')) => {
+ // Escaped character โ take the next char literally.
+ if let Some((_, c)) = chars.next() {
+ val.push(c);
+ }
+ }
+ Some((i, '"')) => {
+ remaining = &remaining[i + 1..];
+ break;
+ }
+ Some((_, c)) => val.push(c),
+ None => {
+ remaining = "";
+ break;
+ }
+ }
+ }
+ value = val;
+ } else {
+ // Unquoted token: read until comma or whitespace.
+ let end = remaining
+ .find(|c: char| c == ',' || c.is_whitespace())
+ .unwrap_or(remaining.len());
+ value = remaining[..end].to_string();
+ remaining = &remaining[end..];
+ }
+
+ if !key.is_empty() {
+ params.insert(key, value);
+ }
+ }
+
+ params
+}
+
+/// Construct the well-known Protected Resource Metadata URIs for a given MCP
+/// server URL, per RFC 9728 Section 3.
+///
+/// Returns URIs in priority order:
+/// 1. Path-specific: `https://<host>/.well-known/oauth-protected-resource/<path>`
+/// 2. Root: `https://<host>/.well-known/oauth-protected-resource`
+pub fn protected_resource_metadata_urls(server_url: &Url) -> Vec<Url> {
+ let mut urls = Vec::new();
+ let base = format!("{}://{}", server_url.scheme(), server_url.authority());
+
+ let path = server_url.path().trim_start_matches('/');
+ if !path.is_empty() {
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/oauth-protected-resource/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ }
+
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-protected-resource", base)) {
+ urls.push(url);
+ }
+
+ urls
+}
+
+/// Construct the well-known Authorization Server Metadata URIs for a given
+/// issuer URL, per RFC 8414 Section 3.1 and Section 5 (OIDC compat).
+///
+/// Returns URIs in priority order, which differs depending on whether the
+/// issuer URL has a path component.
+pub fn auth_server_metadata_urls(issuer: &Url) -> Vec<Url> {
+ let mut urls = Vec::new();
+ let base = format!("{}://{}", issuer.scheme(), issuer.authority());
+ let path = issuer.path().trim_matches('/');
+
+ if !path.is_empty() {
+ // Issuer with path: try path-inserted variants first.
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/oauth-authorization-server/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!(
+ "{}/.well-known/openid-configuration/{}",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!(
+ "{}/{}/.well-known/openid-configuration",
+ base, path
+ )) {
+ urls.push(url);
+ }
+ } else {
+ // No path: standard well-known locations.
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/oauth-authorization-server", base)) {
+ urls.push(url);
+ }
+ if let Ok(url) = Url::parse(&format!("{}/.well-known/openid-configuration", base)) {
+ urls.push(url);
+ }
+ }
+
+ urls
+}
+
+// -- Canonical server URI (RFC 8707) -----------------------------------------
+
+/// Derive the canonical resource URI for an MCP server URL, suitable for the
+/// `resource` parameter in authorization and token requests per RFC 8707.
+///
+/// Lowercases the scheme and host, preserves the path (without trailing slash),
+/// strips fragments and query strings.
+pub fn canonical_server_uri(server_url: &Url) -> String {
+ let mut uri = format!(
+ "{}://{}",
+ server_url.scheme().to_ascii_lowercase(),
+ server_url.host_str().unwrap_or("").to_ascii_lowercase(),
+ );
+ if let Some(port) = server_url.port() {
+ uri.push_str(&format!(":{}", port));
+ }
+ let path = server_url.path();
+ if path != "/" {
+ uri.push_str(path.trim_end_matches('/'));
+ }
+ uri
+}
+
+// -- Scope selection ---------------------------------------------------------
+
+/// Select scopes following the MCP spec's Scope Selection Strategy:
+/// 1. Use `scope` from the `WWW-Authenticate` challenge if present.
+/// 2. Fall back to `scopes_supported` from Protected Resource Metadata.
+/// 3. Return empty if neither is available.
+pub fn select_scopes(
+ www_authenticate: &WwwAuthenticate,
+ resource_metadata: &ProtectedResourceMetadata,
+) -> Vec<String> {
+ if let Some(ref scopes) = www_authenticate.scope {
+ if !scopes.is_empty() {
+ return scopes.clone();
+ }
+ }
+ resource_metadata
+ .scopes_supported
+ .clone()
+ .unwrap_or_default()
+}
+
+// -- Client registration strategy --------------------------------------------
+
+/// The registration approach to use, determined from auth server metadata.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum ClientRegistrationStrategy {
+ /// The auth server supports CIMD. Use the CIMD URL as client_id directly.
+ Cimd { client_id: String },
+ /// The auth server has a registration endpoint. Caller must POST to it.
+ Dcr { registration_endpoint: Url },
+ /// No supported registration mechanism.
+ Unavailable,
+}
+
+/// Determine how to register with the authorization server, following the
+/// spec's recommended priority: CIMD first, DCR fallback.
+pub fn determine_registration_strategy(
+ auth_server_metadata: &AuthServerMetadata,
+) -> ClientRegistrationStrategy {
+ if auth_server_metadata.client_id_metadata_document_supported {
+ ClientRegistrationStrategy::Cimd {
+ client_id: CIMD_URL.to_string(),
+ }
+ } else if let Some(ref endpoint) = auth_server_metadata.registration_endpoint {
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint: endpoint.clone(),
+ }
+ } else {
+ ClientRegistrationStrategy::Unavailable
+ }
+}
+
+// -- PKCE (RFC 7636) ---------------------------------------------------------
+
+/// A PKCE code verifier and its S256 challenge.
+#[derive(Clone)]
+pub struct PkceChallenge {
+ pub verifier: String,
+ pub challenge: String,
+}
+
+impl std::fmt::Debug for PkceChallenge {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("PkceChallenge")
+ .field("verifier", &"[redacted]")
+ .field("challenge", &self.challenge)
+ .finish()
+ }
+}
+
+/// Generate a PKCE code verifier and S256 challenge per RFC 7636.
+///
+/// The verifier is 43 base64url characters derived from 32 random bytes.
+/// The challenge is `BASE64URL(SHA256(verifier))`.
+pub fn generate_pkce_challenge() -> PkceChallenge {
+ let mut random_bytes = [0u8; 32];
+ rand::rng().fill(&mut random_bytes);
+ let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
+ let verifier = engine.encode(&random_bytes);
+
+ let digest = Sha256::digest(verifier.as_bytes());
+ let challenge = engine.encode(digest);
+
+ PkceChallenge {
+ verifier,
+ challenge,
+ }
+}
+
+// -- Authorization URL construction ------------------------------------------
+
+/// Build the authorization URL for the OAuth Authorization Code + PKCE flow.
+pub fn build_authorization_url(
+ auth_server_metadata: &AuthServerMetadata,
+ client_id: &str,
+ redirect_uri: &str,
+ scopes: &[String],
+ resource: &str,
+ pkce: &PkceChallenge,
+ state: &str,
+) -> Url {
+ let mut url = auth_server_metadata.authorization_endpoint.clone();
+ {
+ let mut query = url.query_pairs_mut();
+ query.append_pair("response_type", "code");
+ query.append_pair("client_id", client_id);
+ query.append_pair("redirect_uri", redirect_uri);
+ if !scopes.is_empty() {
+ query.append_pair("scope", &scopes.join(" "));
+ }
+ query.append_pair("resource", resource);
+ query.append_pair("code_challenge", &pkce.challenge);
+ query.append_pair("code_challenge_method", "S256");
+ query.append_pair("state", state);
+ }
+ url
+}
+
+// -- Token endpoint request bodies -------------------------------------------
+
+/// The JSON body returned by the token endpoint on success.
+#[derive(Deserialize)]
+pub struct TokenResponse {
+ pub access_token: String,
+ #[serde(default)]
+ pub refresh_token: Option<String>,
+ #[serde(default)]
+ pub expires_in: Option<u64>,
+ #[serde(default)]
+ pub token_type: Option<String>,
+}
+
+impl std::fmt::Debug for TokenResponse {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("TokenResponse")
+ .field("access_token", &"[redacted]")
+ .field(
+ "refresh_token",
+ &self.refresh_token.as_ref().map(|_| "[redacted]"),
+ )
+ .field("expires_in", &self.expires_in)
+ .field("token_type", &self.token_type)
+ .finish()
+ }
+}
+
+impl TokenResponse {
+ /// Convert into `OAuthTokens`, computing `expires_at` from `expires_in`.
+ pub fn into_tokens(self) -> OAuthTokens {
+ let expires_at = self
+ .expires_in
+ .map(|secs| SystemTime::now() + Duration::from_secs(secs));
+ OAuthTokens {
+ access_token: self.access_token,
+ refresh_token: self.refresh_token,
+ expires_at,
+ }
+ }
+}
+
+/// Build the form-encoded body for an authorization code token exchange.
+pub fn token_exchange_params(
+ code: &str,
+ client_id: &str,
+ redirect_uri: &str,
+ code_verifier: &str,
+ resource: &str,
+) -> Vec<(&'static str, String)> {
+ vec![
+ ("grant_type", "authorization_code".to_string()),
+ ("code", code.to_string()),
+ ("redirect_uri", redirect_uri.to_string()),
+ ("client_id", client_id.to_string()),
+ ("code_verifier", code_verifier.to_string()),
+ ("resource", resource.to_string()),
+ ]
+}
+
+/// Build the form-encoded body for a token refresh request.
+pub fn token_refresh_params(
+ refresh_token: &str,
+ client_id: &str,
+ resource: &str,
+) -> Vec<(&'static str, String)> {
+ vec![
+ ("grant_type", "refresh_token".to_string()),
+ ("refresh_token", refresh_token.to_string()),
+ ("client_id", client_id.to_string()),
+ ("resource", resource.to_string()),
+ ]
+}
+
+// -- DCR request body (RFC 7591) ---------------------------------------------
+
+/// Build the JSON body for a Dynamic Client Registration request.
+///
+/// The `redirect_uri` should be the actual loopback URI with the ephemeral
+/// port (e.g. `http://127.0.0.1:12345/callback`). Some auth servers do strict
+/// redirect URI matching even for loopback addresses, so we register the
+/// exact URI we intend to use.
+pub fn dcr_registration_body(redirect_uri: &str) -> serde_json::Value {
+ serde_json::json!({
+ "client_name": "Zed",
+ "redirect_uris": [redirect_uri],
+ "grant_types": ["authorization_code"],
+ "response_types": ["code"],
+ "token_endpoint_auth_method": "none"
+ })
+}
+
+// -- Discovery (async, hits real endpoints) ----------------------------------
+
+/// Fetch Protected Resource Metadata from the MCP server.
+///
+/// Tries the `resource_metadata` URL from the `WWW-Authenticate` header first,
+/// then falls back to well-known URIs constructed from `server_url`.
+pub async fn fetch_protected_resource_metadata(
+ http_client: &Arc<dyn HttpClient>,
+ server_url: &Url,
+ www_authenticate: &WwwAuthenticate,
+) -> Result<ProtectedResourceMetadata> {
+ let candidate_urls = match &www_authenticate.resource_metadata {
+ Some(url) if url.origin() == server_url.origin() => vec![url.clone()],
+ Some(url) => {
+ log::warn!(
+ "Ignoring cross-origin resource_metadata URL {} \
+ (server origin: {})",
+ url,
+ server_url.origin().unicode_serialization()
+ );
+ protected_resource_metadata_urls(server_url)
+ }
+ None => protected_resource_metadata_urls(server_url),
+ };
+
+ for url in &candidate_urls {
+ match fetch_json::<ProtectedResourceMetadataResponse>(http_client, url).await {
+ Ok(response) => {
+ if response.authorization_servers.is_empty() {
+ bail!(
+ "Protected Resource Metadata at {} has no authorization_servers",
+ url
+ );
+ }
+ return Ok(ProtectedResourceMetadata {
+ resource: response.resource.unwrap_or_else(|| server_url.clone()),
+ authorization_servers: response.authorization_servers,
+ scopes_supported: response.scopes_supported,
+ });
+ }
+ Err(err) => {
+ log::debug!(
+ "Failed to fetch Protected Resource Metadata from {}: {}",
+ url,
+ err
+ );
+ }
+ }
+ }
+
+ bail!(
+ "Could not fetch Protected Resource Metadata for {}",
+ server_url
+ )
+}
+
+/// Fetch Authorization Server Metadata, trying RFC 8414 and OIDC Discovery
+/// endpoints in the priority order specified by the MCP spec.
+pub async fn fetch_auth_server_metadata(
+ http_client: &Arc<dyn HttpClient>,
+ issuer: &Url,
+) -> Result<AuthServerMetadata> {
+ let candidate_urls = auth_server_metadata_urls(issuer);
+
+ for url in &candidate_urls {
+ match fetch_json::<AuthServerMetadataResponse>(http_client, url).await {
+ Ok(response) => {
+ let reported_issuer = response.issuer.unwrap_or_else(|| issuer.clone());
+ if reported_issuer != *issuer {
+ bail!(
+ "Auth server metadata issuer mismatch: expected {}, got {}",
+ issuer,
+ reported_issuer
+ );
+ }
+
+ return Ok(AuthServerMetadata {
+ issuer: reported_issuer,
+ authorization_endpoint: response
+ .authorization_endpoint
+ .ok_or_else(|| anyhow!("missing authorization_endpoint"))?,
+ token_endpoint: response
+ .token_endpoint
+ .ok_or_else(|| anyhow!("missing token_endpoint"))?,
+ registration_endpoint: response.registration_endpoint,
+ scopes_supported: response.scopes_supported,
+ code_challenge_methods_supported: response.code_challenge_methods_supported,
+ client_id_metadata_document_supported: response
+ .client_id_metadata_document_supported
+ .unwrap_or(false),
+ });
+ }
+ Err(err) => {
+ log::debug!("Failed to fetch Auth Server Metadata from {}: {}", url, err);
+ }
+ }
+ }
+
+ bail!(
+ "Could not fetch Authorization Server Metadata for {}",
+ issuer
+ )
+}
+
+/// Run the full discovery flow: fetch resource metadata, then auth server
+/// metadata, then select scopes. Client registration is resolved separately,
+/// once the real redirect URI is known.
+pub async fn discover(
+ http_client: &Arc<dyn HttpClient>,
+ server_url: &Url,
+ www_authenticate: &WwwAuthenticate,
+) -> Result<OAuthDiscovery> {
+ let resource_metadata =
+ fetch_protected_resource_metadata(http_client, server_url, www_authenticate).await?;
+
+ let auth_server_url = resource_metadata
+ .authorization_servers
+ .first()
+ .ok_or_else(|| anyhow!("no authorization servers in resource metadata"))?;
+
+ let auth_server_metadata = fetch_auth_server_metadata(http_client, auth_server_url).await?;
+
+ // Verify PKCE S256 support (spec requirement).
+ match &auth_server_metadata.code_challenge_methods_supported {
+ Some(methods) if methods.iter().any(|m| m == "S256") => {}
+ Some(_) => bail!("authorization server does not support S256 PKCE"),
+ None => bail!("authorization server does not advertise code_challenge_methods_supported"),
+ }
+
+ // Verify there is at least one supported registration strategy before we
+ // present the server as ready to authenticate.
+ match determine_registration_strategy(&auth_server_metadata) {
+ ClientRegistrationStrategy::Cimd { .. } | ClientRegistrationStrategy::Dcr { .. } => {}
+ ClientRegistrationStrategy::Unavailable => {
+ bail!("authorization server supports neither CIMD nor DCR")
+ }
+ }
+
+ let scopes = select_scopes(www_authenticate, &resource_metadata);
+
+ Ok(OAuthDiscovery {
+ resource_metadata,
+ auth_server_metadata,
+ scopes,
+ })
+}
+
+/// Resolve the OAuth client registration for an authorization flow.
+///
+/// CIMD uses the static client metadata document directly. For DCR, a fresh
+/// registration is performed each time because the loopback redirect URI
+/// includes an ephemeral port that changes every flow.
+pub async fn resolve_client_registration(
+ http_client: &Arc<dyn HttpClient>,
+ discovery: &OAuthDiscovery,
+ redirect_uri: &str,
+) -> Result<OAuthClientRegistration> {
+ match determine_registration_strategy(&discovery.auth_server_metadata) {
+ ClientRegistrationStrategy::Cimd { client_id } => Ok(OAuthClientRegistration {
+ client_id,
+ client_secret: None,
+ }),
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint,
+ } => perform_dcr(http_client, ®istration_endpoint, redirect_uri).await,
+ ClientRegistrationStrategy::Unavailable => {
+ bail!("authorization server supports neither CIMD nor DCR")
+ }
+ }
+}
+
+// -- Dynamic Client Registration (RFC 7591) ----------------------------------
+
+/// Perform Dynamic Client Registration with the authorization server.
+pub async fn perform_dcr(
+ http_client: &Arc<dyn HttpClient>,
+ registration_endpoint: &Url,
+ redirect_uri: &str,
+) -> Result<OAuthClientRegistration> {
+ validate_oauth_url(registration_endpoint)?;
+
+ let body = dcr_registration_body(redirect_uri);
+ let body_bytes = serde_json::to_vec(&body)?;
+
+ let request = Request::builder()
+ .method(http_client::http::Method::POST)
+ .uri(registration_endpoint.as_str())
+ .header("Content-Type", "application/json")
+ .header("Accept", "application/json")
+ .body(AsyncBody::from(body_bytes))?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ let mut error_body = String::new();
+ response.body_mut().read_to_string(&mut error_body).await?;
+ bail!(
+ "DCR failed with status {}: {}",
+ response.status(),
+ error_body
+ );
+ }
+
+ let mut response_body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut response_body)
+ .await?;
+
+ let dcr_response: DcrResponse =
+ serde_json::from_str(&response_body).context("failed to parse DCR response")?;
+
+ Ok(OAuthClientRegistration {
+ client_id: dcr_response.client_id,
+ client_secret: dcr_response.client_secret,
+ })
+}
+
+// -- Token exchange and refresh (async) --------------------------------------
+
+/// Exchange an authorization code for tokens at the token endpoint.
+pub async fn exchange_code(
+ http_client: &Arc<dyn HttpClient>,
+ auth_server_metadata: &AuthServerMetadata,
+ code: &str,
+ client_id: &str,
+ redirect_uri: &str,
+ code_verifier: &str,
+ resource: &str,
+) -> Result<OAuthTokens> {
+ let params = token_exchange_params(code, client_id, redirect_uri, code_verifier, resource);
+ post_token_request(http_client, &auth_server_metadata.token_endpoint, ¶ms).await
+}
+
+/// Refresh tokens using a refresh token.
+pub async fn refresh_tokens(
+ http_client: &Arc<dyn HttpClient>,
+ token_endpoint: &Url,
+ refresh_token: &str,
+ client_id: &str,
+ resource: &str,
+) -> Result<OAuthTokens> {
+ let params = token_refresh_params(refresh_token, client_id, resource);
+ post_token_request(http_client, token_endpoint, ¶ms).await
+}
+
+/// POST form-encoded parameters to a token endpoint and parse the response.
+async fn post_token_request(
+ http_client: &Arc<dyn HttpClient>,
+ token_endpoint: &Url,
+ params: &[(&str, String)],
+) -> Result<OAuthTokens> {
+ validate_oauth_url(token_endpoint)?;
+
+ let body = url::form_urlencoded::Serializer::new(String::new())
+ .extend_pairs(params.iter().map(|(k, v)| (*k, v.as_str())))
+ .finish();
+
+ let request = Request::builder()
+ .method(http_client::http::Method::POST)
+ .uri(token_endpoint.as_str())
+ .header("Content-Type", "application/x-www-form-urlencoded")
+ .header("Accept", "application/json")
+ .body(AsyncBody::from(body.into_bytes()))?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ let mut error_body = String::new();
+ response.body_mut().read_to_string(&mut error_body).await?;
+ bail!(
+ "token request failed with status {}: {}",
+ response.status(),
+ error_body
+ );
+ }
+
+ let mut response_body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut response_body)
+ .await?;
+
+ let token_response: TokenResponse =
+ serde_json::from_str(&response_body).context("failed to parse token response")?;
+
+ Ok(token_response.into_tokens())
+}
+
+// -- Loopback HTTP callback server -------------------------------------------
+
+/// An OAuth authorization callback received via the loopback HTTP server.
+pub struct OAuthCallback {
+ pub code: String,
+ pub state: String,
+}
+
+impl std::fmt::Debug for OAuthCallback {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("OAuthCallback")
+ .field("code", &"[redacted]")
+ .field("state", &"[redacted]")
+ .finish()
+ }
+}
+
+impl OAuthCallback {
+ /// Parse the query string from a callback URL like
+ /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
+ pub fn parse_query(query: &str) -> Result<Self> {
+ let mut code: Option<String> = None;
+ let mut state: Option<String> = None;
+ let mut error: Option<String> = None;
+ let mut error_description: Option<String> = None;
+
+ for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
+ match key.as_ref() {
+ "code" => {
+ if !value.is_empty() {
+ code = Some(value.into_owned());
+ }
+ }
+ "state" => {
+ if !value.is_empty() {
+ state = Some(value.into_owned());
+ }
+ }
+ "error" => {
+ if !value.is_empty() {
+ error = Some(value.into_owned());
+ }
+ }
+ "error_description" => {
+ if !value.is_empty() {
+ error_description = Some(value.into_owned());
+ }
+ }
+ _ => {}
+ }
+ }
+
+ // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
+ // checking for missing code/state.
+ if let Some(error_code) = error {
+ bail!(
+ "OAuth authorization failed: {} ({})",
+ error_code,
+ error_description.as_deref().unwrap_or("no description")
+ );
+ }
+
+ let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
+ let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
+
+ Ok(Self { code, state })
+ }
+}
+
+/// How long to wait for the browser to complete the OAuth flow before giving
+/// up and releasing the loopback port.
+const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
+
+/// Start a loopback HTTP server to receive the OAuth authorization callback.
+///
+/// Binds to an ephemeral loopback port for each flow.
+///
+/// Returns `(redirect_uri, callback_future)`. The caller should use the
+/// redirect URI in the authorization request, open the browser, then await
+/// the future to receive the callback.
+///
+/// The server accepts exactly one request on `/callback`, validates that it
+/// contains `code` and `state` query parameters, responds with a minimal
+/// HTML page telling the user they can close the tab, and shuts down.
+///
+/// The callback server shuts down when the returned oneshot receiver is dropped
+/// (e.g. because the authentication task was cancelled), or after a timeout
+/// ([CALLBACK_TIMEOUT]).
+pub async fn start_callback_server() -> Result<(
+ String,
+ futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
+)> {
+ let server = tiny_http::Server::http("127.0.0.1:0")
+ .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
+ let port = server
+ .server_addr()
+ .to_ip()
+ .context("server not bound to a TCP address")?
+ .port();
+
+ let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
+
+ let (tx, rx) = futures::channel::oneshot::channel();
+
+ // `tiny_http` is blocking, so we run it on a background thread.
+ // The `recv_timeout` loop lets us check for cancellation (the receiver
+ // being dropped) and enforce an overall timeout.
+ std::thread::spawn(move || {
+ let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
+
+ loop {
+ if tx.is_canceled() {
+ return;
+ }
+ let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+ if remaining.is_zero() {
+ return;
+ }
+
+ let timeout = remaining.min(Duration::from_millis(500));
+ let Some(request) = (match server.recv_timeout(timeout) {
+ Ok(req) => req,
+ Err(_) => {
+ let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
+ return;
+ }
+ }) else {
+ // Timeout with no request โ loop back and check cancellation.
+ continue;
+ };
+
+ let result = handle_callback_request(&request);
+
+ let (status_code, body) = match &result {
+ Ok(_) => (
+ 200,
+ "<html><body><h1>Authorization successful</h1>\
+ <p>You can close this tab and return to Zed.</p></body></html>",
+ ),
+ Err(err) => {
+ log::error!("OAuth callback error: {}", err);
+ (
+ 400,
+ "<html><body><h1>Authorization failed</h1>\
+ <p>Something went wrong. Please try again from Zed.</p></body></html>",
+ )
+ }
+ };
+
+ let response = tiny_http::Response::from_string(body)
+ .with_status_code(status_code)
+ .with_header(
+ tiny_http::Header::from_str("Content-Type: text/html")
+ .expect("failed to construct response header"),
+ )
+ .with_header(
+ tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
+ .expect("failed to construct response header"),
+ );
+ request.respond(response).log_err();
+
+ let _ = tx.send(result);
+ return;
+ }
+ });
+
+ Ok((redirect_uri, rx))
+}
+
+/// Extract the `code` and `state` query parameters from an OAuth callback
+/// request to `/callback`.
+fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
+ let url = Url::parse(&format!("http://localhost{}", request.url()))
+ .context("malformed callback request URL")?;
+
+ if url.path() != "/callback" {
+ bail!("unexpected path in OAuth callback: {}", url.path());
+ }
+
+ let query = url
+ .query()
+ .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
+ OAuthCallback::parse_query(query)
+}
+
+// -- JSON fetch helper -------------------------------------------------------
+
+async fn fetch_json<T: serde::de::DeserializeOwned>(
+ http_client: &Arc<dyn HttpClient>,
+ url: &Url,
+) -> Result<T> {
+ validate_oauth_url(url)?;
+
+ let request = Request::builder()
+ .method(http_client::http::Method::GET)
+ .uri(url.as_str())
+ .header("Accept", "application/json")
+ .body(AsyncBody::default())?;
+
+ let mut response = http_client.send(request).await?;
+
+ if !response.status().is_success() {
+ bail!("HTTP {} fetching {}", response.status(), url);
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ serde_json::from_str(&body).with_context(|| format!("failed to parse JSON from {}", url))
+}
+
+// -- Serde response types for discovery --------------------------------------
+
+#[derive(Debug, Deserialize)]
+struct ProtectedResourceMetadataResponse {
+ #[serde(default)]
+ resource: Option<Url>,
+ #[serde(default)]
+ authorization_servers: Vec<Url>,
+ #[serde(default)]
+ scopes_supported: Option<Vec<String>>,
+}
+
+#[derive(Debug, Deserialize)]
+struct AuthServerMetadataResponse {
+ #[serde(default)]
+ issuer: Option<Url>,
+ #[serde(default)]
+ authorization_endpoint: Option<Url>,
+ #[serde(default)]
+ token_endpoint: Option<Url>,
+ #[serde(default)]
+ registration_endpoint: Option<Url>,
+ #[serde(default)]
+ scopes_supported: Option<Vec<String>>,
+ #[serde(default)]
+ code_challenge_methods_supported: Option<Vec<String>>,
+ #[serde(default)]
+ client_id_metadata_document_supported: Option<bool>,
+}
+
+#[derive(Debug, Deserialize)]
+struct DcrResponse {
+ client_id: String,
+ #[serde(default)]
+ client_secret: Option<String>,
+}
+
+/// Provides OAuth tokens to the HTTP transport layer.
+///
+/// The transport calls `access_token()` before each request. On a 401 response
+/// it calls `try_refresh()` and retries once if the refresh succeeds.
+#[async_trait]
+pub trait OAuthTokenProvider: Send + Sync {
+ /// Returns the current access token, if one is available.
+ fn access_token(&self) -> Option<String>;
+
+ /// Attempts to refresh the access token. Returns `true` if a new token was
+ /// obtained and the request should be retried.
+ async fn try_refresh(&self) -> Result<bool>;
+}
+
+/// Concrete `OAuthTokenProvider` backed by a full persisted OAuth session and
+/// an HTTP client for token refresh. The same provider type is used both after
+/// an interactive authentication flow and when restoring a saved session from
+/// the keychain on startup.
+pub struct McpOAuthTokenProvider {
+ session: SyncMutex<OAuthSession>,
+ http_client: Arc<dyn HttpClient>,
+ token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
+}
+
+impl McpOAuthTokenProvider {
+ pub fn new(
+ session: OAuthSession,
+ http_client: Arc<dyn HttpClient>,
+ token_refresh_tx: Option<mpsc::UnboundedSender<OAuthSession>>,
+ ) -> Self {
+ Self {
+ session: SyncMutex::new(session),
+ http_client,
+ token_refresh_tx,
+ }
+ }
+
+ fn access_token_is_expired(tokens: &OAuthTokens) -> bool {
+ tokens.expires_at.is_some_and(|expires_at| {
+ SystemTime::now()
+ .checked_add(Duration::from_secs(30))
+ .is_some_and(|now_with_buffer| expires_at <= now_with_buffer)
+ })
+ }
+}
+
+#[async_trait]
+impl OAuthTokenProvider for McpOAuthTokenProvider {
+ fn access_token(&self) -> Option<String> {
+ let session = self.session.lock();
+ if Self::access_token_is_expired(&session.tokens) {
+ return None;
+ }
+ Some(session.tokens.access_token.clone())
+ }
+
+ async fn try_refresh(&self) -> Result<bool> {
+ let (refresh_token, token_endpoint, resource, client_id) = {
+ let session = self.session.lock();
+ match session.tokens.refresh_token.clone() {
+ Some(refresh_token) => (
+ refresh_token,
+ session.token_endpoint.clone(),
+ session.resource.clone(),
+ session.client_registration.client_id.clone(),
+ ),
+ None => return Ok(false),
+ }
+ };
+
+ let resource_str = canonical_server_uri(&resource);
+
+ match refresh_tokens(
+ &self.http_client,
+ &token_endpoint,
+ &refresh_token,
+ &client_id,
+ &resource_str,
+ )
+ .await
+ {
+ Ok(mut new_tokens) => {
+ if new_tokens.refresh_token.is_none() {
+ new_tokens.refresh_token = Some(refresh_token);
+ }
+
+ {
+ let mut session = self.session.lock();
+ session.tokens = new_tokens;
+
+ if let Some(ref tx) = self.token_refresh_tx {
+ tx.unbounded_send(session.clone()).ok();
+ }
+ }
+
+ Ok(true)
+ }
+ Err(err) => {
+ log::warn!("OAuth token refresh failed: {}", err);
+ Ok(false)
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use http_client::Response;
+
+ // -- require_https_or_loopback tests ------------------------------------
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_https() {
+ let url = Url::parse("https://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_remote() {
+ let url = Url::parse("http://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_127_0_0_1() {
+ let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_ipv6_loopback() {
+ let url = Url::parse("http://[::1]:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost() {
+ let url = Url::parse("http://localhost:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_accepts_http_localhost_case_insensitive() {
+ let url = Url::parse("http://LOCALHOST:8080/callback").unwrap();
+ assert!(require_https_or_loopback(&url).is_ok());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_http_non_loopback_ip() {
+ let url = Url::parse("http://192.168.1.1:8080/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ #[test]
+ fn test_require_https_or_loopback_rejects_ftp() {
+ let url = Url::parse("ftp://auth.example.com/token").unwrap();
+ assert!(require_https_or_loopback(&url).is_err());
+ }
+
+ // -- validate_oauth_url (SSRF) tests ------------------------------------
+
+ #[test]
+ fn test_validate_oauth_url_accepts_https_public() {
+ let url = Url::parse("https://auth.example.com/token").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_10() {
+ let url = Url::parse("https://10.0.0.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_172() {
+ let url = Url::parse("https://172.16.0.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_private_ipv4_192() {
+ let url = Url::parse("https://192.168.1.1/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_link_local() {
+ let url = Url::parse("https://169.254.169.254/latest/meta-data/").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv6_ula() {
+ let url = Url::parse("https://[fd12:3456:789a::1]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv6_unspecified() {
+ let url = Url::parse("https://[::]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_private() {
+ let url = Url::parse("https://[::ffff:10.0.0.1]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_rejects_ipv4_mapped_ipv6_link_local() {
+ let url = Url::parse("https://[::ffff:169.254.169.254]/token").unwrap();
+ assert!(validate_oauth_url(&url).is_err());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_allows_http_loopback() {
+ // Loopback is permitted (it's our callback server).
+ let url = Url::parse("http://127.0.0.1:8080/callback").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ #[test]
+ fn test_validate_oauth_url_allows_https_public_ip() {
+ let url = Url::parse("https://93.184.216.34/token").unwrap();
+ assert!(validate_oauth_url(&url).is_ok());
+ }
+
+ // -- parse_www_authenticate tests ----------------------------------------
+
+ #[test]
+ fn test_parse_www_authenticate_with_resource_metadata_and_scope() {
+ let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="files:read user:profile""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(
+ result.resource_metadata.as_ref().map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+ );
+ assert_eq!(
+ result.scope,
+ Some(vec!["files:read".to_string(), "user:profile".to_string()])
+ );
+ assert_eq!(result.error, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_resource_metadata_only() {
+ let header = r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(
+ result.resource_metadata.as_ref().map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource")
+ );
+ assert_eq!(result.scope, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_bare_bearer() {
+ let result = parse_www_authenticate("Bearer").unwrap();
+ assert_eq!(result.resource_metadata, None);
+ assert_eq!(result.scope, None);
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_with_error() {
+ let header = r#"Bearer error="insufficient_scope", scope="files:read files:write", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", error_description="Additional file write permission required""#;
+ let result = parse_www_authenticate(header).unwrap();
+
+ assert_eq!(result.error, Some(BearerError::InsufficientScope));
+ assert_eq!(
+ result.error_description.as_deref(),
+ Some("Additional file write permission required")
+ );
+ assert_eq!(
+ result.scope,
+ Some(vec!["files:read".to_string(), "files:write".to_string()])
+ );
+ assert!(result.resource_metadata.is_some());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_invalid_token_error() {
+ let header =
+ r#"Bearer error="invalid_token", error_description="The access token expired""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::InvalidToken));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_invalid_request_error() {
+ let header = r#"Bearer error="invalid_request""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::InvalidRequest));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_unknown_error() {
+ let header = r#"Bearer error="some_future_error""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert_eq!(result.error, Some(BearerError::Other));
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_rejects_non_bearer() {
+ let result = parse_www_authenticate("Basic realm=\"example\"");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_case_insensitive_scheme() {
+ let header = r#"bearer resource_metadata="https://example.com/.well-known/oauth-protected-resource""#;
+ let result = parse_www_authenticate(header).unwrap();
+ assert!(result.resource_metadata.is_some());
+ }
+
+ #[test]
+ fn test_parse_www_authenticate_multiline_style() {
+ // Some servers emit the header spread across multiple lines joined by
+ // whitespace, as shown in the spec examples.
+ let header = "Bearer resource_metadata=\"https://mcp.example.com/.well-known/oauth-protected-resource\",\n scope=\"files:read\"";
+ let result = parse_www_authenticate(header).unwrap();
+ assert!(result.resource_metadata.is_some());
+ assert_eq!(result.scope, Some(vec!["files:read".to_string()]));
+ }
+
+ #[test]
+ fn test_protected_resource_metadata_urls_with_path() {
+ let server_url = Url::parse("https://api.example.com/v1/mcp").unwrap();
+ let urls = protected_resource_metadata_urls(&server_url);
+
+ assert_eq!(urls.len(), 2);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://api.example.com/.well-known/oauth-protected-resource"
+ );
+ }
+
+ #[test]
+ fn test_protected_resource_metadata_urls_without_path() {
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let urls = protected_resource_metadata_urls(&server_url);
+
+ assert_eq!(urls.len(), 1);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://mcp.example.com/.well-known/oauth-protected-resource"
+ );
+ }
+
+ #[test]
+ fn test_auth_server_metadata_urls_with_path() {
+ let issuer = Url::parse("https://auth.example.com/tenant1").unwrap();
+ let urls = auth_server_metadata_urls(&issuer);
+
+ assert_eq!(urls.len(), 3);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://auth.example.com/.well-known/oauth-authorization-server/tenant1"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://auth.example.com/.well-known/openid-configuration/tenant1"
+ );
+ assert_eq!(
+ urls[2].as_str(),
+ "https://auth.example.com/tenant1/.well-known/openid-configuration"
+ );
+ }
+
+ #[test]
+ fn test_auth_server_metadata_urls_without_path() {
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let urls = auth_server_metadata_urls(&issuer);
+
+ assert_eq!(urls.len(), 2);
+ assert_eq!(
+ urls[0].as_str(),
+ "https://auth.example.com/.well-known/oauth-authorization-server"
+ );
+ assert_eq!(
+ urls[1].as_str(),
+ "https://auth.example.com/.well-known/openid-configuration"
+ );
+ }
+
+ // -- Canonical server URI tests ------------------------------------------
+
+ #[test]
+ fn test_canonical_server_uri_simple() {
+ let url = Url::parse("https://mcp.example.com").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_with_path() {
+ let url = Url::parse("https://mcp.example.com/v1/mcp").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com/v1/mcp");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_strips_trailing_slash() {
+ let url = Url::parse("https://mcp.example.com/").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_preserves_port() {
+ let url = Url::parse("https://mcp.example.com:8443").unwrap();
+ assert_eq!(canonical_server_uri(&url), "https://mcp.example.com:8443");
+ }
+
+ #[test]
+ fn test_canonical_server_uri_lowercases() {
+ let url = Url::parse("HTTPS://MCP.Example.COM/Server/MCP").unwrap();
+ assert_eq!(
+ canonical_server_uri(&url),
+ "https://mcp.example.com/Server/MCP"
+ );
+ }
+
+ // -- Scope selection tests -----------------------------------------------
+
+ #[test]
+ fn test_select_scopes_prefers_www_authenticate() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: Some(vec!["files:read".into()]),
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: Some(vec!["files:read".into(), "files:write".into()]),
+ };
+ assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["files:read"]);
+ }
+
+ #[test]
+ fn test_select_scopes_falls_back_to_resource_metadata() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: Some(vec!["admin".into()]),
+ };
+ assert_eq!(select_scopes(&www_auth, &resource_meta), vec!["admin"]);
+ }
+
+ #[test]
+ fn test_select_scopes_empty_when_nothing_available() {
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let resource_meta = ProtectedResourceMetadata {
+ resource: Url::parse("https://example.com").unwrap(),
+ authorization_servers: vec![],
+ scopes_supported: None,
+ };
+ assert!(select_scopes(&www_auth, &resource_meta).is_empty());
+ }
+
+ // -- Client registration strategy tests ----------------------------------
+
+ #[test]
+ fn test_registration_strategy_prefers_cimd() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: Some(Url::parse("https://auth.example.com/register").unwrap()),
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Cimd {
+ client_id: CIMD_URL.to_string(),
+ }
+ );
+ }
+
+ #[test]
+ fn test_registration_strategy_falls_back_to_dcr() {
+ let reg_endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: Some(reg_endpoint.clone()),
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Dcr {
+ registration_endpoint: reg_endpoint,
+ }
+ );
+ }
+
+ #[test]
+ fn test_registration_strategy_unavailable() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ assert_eq!(
+ determine_registration_strategy(&metadata),
+ ClientRegistrationStrategy::Unavailable,
+ );
+ }
+
+ // -- PKCE tests ----------------------------------------------------------
+
+ #[test]
+ fn test_pkce_challenge_verifier_length() {
+ let pkce = generate_pkce_challenge();
+ // 32 random bytes โ 43 base64url chars (no padding).
+ assert_eq!(pkce.verifier.len(), 43);
+ }
+
+ #[test]
+ fn test_pkce_challenge_is_valid_base64url() {
+ let pkce = generate_pkce_challenge();
+ for c in pkce.verifier.chars().chain(pkce.challenge.chars()) {
+ assert!(
+ c.is_ascii_alphanumeric() || c == '-' || c == '_',
+ "invalid base64url character: {}",
+ c
+ );
+ }
+ }
+
+ #[test]
+ fn test_pkce_challenge_is_s256_of_verifier() {
+ let pkce = generate_pkce_challenge();
+ let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
+ let expected_digest = Sha256::digest(pkce.verifier.as_bytes());
+ let expected_challenge = engine.encode(expected_digest);
+ assert_eq!(pkce.challenge, expected_challenge);
+ }
+
+ #[test]
+ fn test_pkce_challenges_are_unique() {
+ let a = generate_pkce_challenge();
+ let b = generate_pkce_challenge();
+ assert_ne!(a.verifier, b.verifier);
+ }
+
+ // -- Authorization URL tests ---------------------------------------------
+
+ #[test]
+ fn test_build_authorization_url() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+ let pkce = PkceChallenge {
+ verifier: "test_verifier".into(),
+ challenge: "test_challenge".into(),
+ };
+ let url = build_authorization_url(
+ &metadata,
+ "https://zed.dev/oauth/client-metadata.json",
+ "http://127.0.0.1:12345/callback",
+ &["files:read".into(), "files:write".into()],
+ "https://mcp.example.com",
+ &pkce,
+ "random_state_123",
+ );
+
+ let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+ assert_eq!(pairs.get("response_type").unwrap(), "code");
+ assert_eq!(
+ pairs.get("client_id").unwrap(),
+ "https://zed.dev/oauth/client-metadata.json"
+ );
+ assert_eq!(
+ pairs.get("redirect_uri").unwrap(),
+ "http://127.0.0.1:12345/callback"
+ );
+ assert_eq!(pairs.get("scope").unwrap(), "files:read files:write");
+ assert_eq!(pairs.get("resource").unwrap(), "https://mcp.example.com");
+ assert_eq!(pairs.get("code_challenge").unwrap(), "test_challenge");
+ assert_eq!(pairs.get("code_challenge_method").unwrap(), "S256");
+ assert_eq!(pairs.get("state").unwrap(), "random_state_123");
+ }
+
+ #[test]
+ fn test_build_authorization_url_omits_empty_scope() {
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: false,
+ };
+ let pkce = PkceChallenge {
+ verifier: "v".into(),
+ challenge: "c".into(),
+ };
+ let url = build_authorization_url(
+ &metadata,
+ "client_123",
+ "http://127.0.0.1:9999/callback",
+ &[],
+ "https://mcp.example.com",
+ &pkce,
+ "state",
+ );
+
+ let pairs: std::collections::HashMap<_, _> = url.query_pairs().collect();
+ assert!(!pairs.contains_key("scope"));
+ }
+
+ // -- Token exchange / refresh param tests --------------------------------
+
+ #[test]
+ fn test_token_exchange_params() {
+ let params = token_exchange_params(
+ "auth_code_abc",
+ "client_xyz",
+ "http://127.0.0.1:5555/callback",
+ "verifier_123",
+ "https://mcp.example.com",
+ );
+ let map: std::collections::HashMap<&str, &str> =
+ params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+ assert_eq!(map["grant_type"], "authorization_code");
+ assert_eq!(map["code"], "auth_code_abc");
+ assert_eq!(map["redirect_uri"], "http://127.0.0.1:5555/callback");
+ assert_eq!(map["client_id"], "client_xyz");
+ assert_eq!(map["code_verifier"], "verifier_123");
+ assert_eq!(map["resource"], "https://mcp.example.com");
+ }
+
+ #[test]
+ fn test_token_refresh_params() {
+ let params =
+ token_refresh_params("refresh_token_abc", "client_xyz", "https://mcp.example.com");
+ let map: std::collections::HashMap<&str, &str> =
+ params.iter().map(|(k, v)| (*k, v.as_str())).collect();
+
+ assert_eq!(map["grant_type"], "refresh_token");
+ assert_eq!(map["refresh_token"], "refresh_token_abc");
+ assert_eq!(map["client_id"], "client_xyz");
+ assert_eq!(map["resource"], "https://mcp.example.com");
+ }
+
+ // -- Token response tests ------------------------------------------------
+
+ #[test]
+ fn test_token_response_into_tokens_with_expiry() {
+ let response: TokenResponse = serde_json::from_str(
+ r#"{"access_token": "at_123", "refresh_token": "rt_456", "expires_in": 3600, "token_type": "Bearer"}"#,
+ )
+ .unwrap();
+
+ let tokens = response.into_tokens();
+ assert_eq!(tokens.access_token, "at_123");
+ assert_eq!(tokens.refresh_token.as_deref(), Some("rt_456"));
+ assert!(tokens.expires_at.is_some());
+ }
+
+ #[test]
+ fn test_token_response_into_tokens_minimal() {
+ let response: TokenResponse =
+ serde_json::from_str(r#"{"access_token": "at_789"}"#).unwrap();
+
+ let tokens = response.into_tokens();
+ assert_eq!(tokens.access_token, "at_789");
+ assert_eq!(tokens.refresh_token, None);
+ assert_eq!(tokens.expires_at, None);
+ }
+
+ // -- DCR body test -------------------------------------------------------
+
+ #[test]
+ fn test_dcr_registration_body_shape() {
+ let body = dcr_registration_body("http://127.0.0.1:12345/callback");
+ assert_eq!(body["client_name"], "Zed");
+ assert_eq!(body["redirect_uris"][0], "http://127.0.0.1:12345/callback");
+ assert_eq!(body["grant_types"][0], "authorization_code");
+ assert_eq!(body["response_types"][0], "code");
+ assert_eq!(body["token_endpoint_auth_method"], "none");
+ }
+
+ // -- Test helpers for async/HTTP tests -----------------------------------
+
+ fn make_fake_http_client(
+ handler: impl Fn(
+ http_client::Request<AsyncBody>,
+ ) -> std::pin::Pin<
+ Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
+ > + Send
+ + Sync
+ + 'static,
+ ) -> Arc<dyn HttpClient> {
+ http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
+ }
+
+ fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
+ Ok(Response::builder()
+ .status(status)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(body.as_bytes().to_vec()))
+ .unwrap())
+ }
+
+ // -- Discovery integration tests -----------------------------------------
+
+ #[test]
+ fn test_fetch_protected_resource_metadata() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"],
+ "scopes_supported": ["read", "write"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+ assert_eq!(metadata.authorization_servers.len(), 1);
+ assert_eq!(
+ metadata.authorization_servers[0].as_str(),
+ "https://auth.example.com/"
+ );
+ assert_eq!(
+ metadata.scopes_supported,
+ Some(vec!["read".to_string(), "write".to_string()])
+ );
+ });
+ }
+
+ #[test]
+ fn test_fetch_protected_resource_metadata_prefers_www_authenticate_url() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri == "https://mcp.example.com/custom-resource-metadata" {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else {
+ json_response(500, r#"{"error": "should not be called"}"#)
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: Some(
+ Url::parse("https://mcp.example.com/custom-resource-metadata").unwrap(),
+ ),
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ assert_eq!(metadata.authorization_servers.len(), 1);
+ });
+ }
+
+ #[test]
+ fn test_fetch_protected_resource_metadata_rejects_cross_origin_url() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ // The cross-origin URL should NOT be fetched; only the
+ // well-known fallback at the server's own origin should be.
+ if uri.contains("attacker.example.com") {
+ panic!("should not fetch cross-origin resource_metadata URL");
+ } else if uri.contains(".well-known/oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: Some(
+ Url::parse("https://attacker.example.com/fake-metadata").unwrap(),
+ ),
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let metadata = fetch_protected_resource_metadata(&client, &server_url, &www_auth)
+ .await
+ .unwrap();
+
+ // Should have used the fallback well-known URL, not the attacker's.
+ assert_eq!(metadata.resource.as_str(), "https://mcp.example.com/");
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": true
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+ assert_eq!(metadata.issuer.as_str(), "https://auth.example.com/");
+ assert_eq!(
+ metadata.authorization_endpoint.as_str(),
+ "https://auth.example.com/authorize"
+ );
+ assert_eq!(
+ metadata.token_endpoint.as_str(),
+ "https://auth.example.com/token"
+ );
+ assert!(metadata.registration_endpoint.is_some());
+ assert!(metadata.client_id_metadata_document_supported);
+ assert_eq!(
+ metadata.code_challenge_methods_supported,
+ Some(vec!["S256".to_string()])
+ );
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata_falls_back_to_oidc() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("openid-configuration") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "code_challenge_methods_supported": ["S256"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let metadata = fetch_auth_server_metadata(&client, &issuer).await.unwrap();
+
+ assert_eq!(
+ metadata.authorization_endpoint.as_str(),
+ "https://auth.example.com/authorize"
+ );
+ assert!(!metadata.client_id_metadata_document_supported);
+ });
+ }
+
+ #[test]
+ fn test_fetch_auth_server_metadata_rejects_issuer_mismatch() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains(".well-known/oauth-authorization-server") {
+ // Response claims to be a different issuer.
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://evil.example.com",
+ "authorization_endpoint": "https://evil.example.com/authorize",
+ "token_endpoint": "https://evil.example.com/token",
+ "code_challenge_methods_supported": ["S256"]
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let issuer = Url::parse("https://auth.example.com").unwrap();
+ let result = fetch_auth_server_metadata(&client, &issuer).await;
+
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("issuer mismatch"),
+ "unexpected error: {}",
+ err_msg
+ );
+ });
+ }
+
+ // -- Full discover integration tests -------------------------------------
+
+ #[test]
+ fn test_full_discover_with_cimd() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"],
+ "scopes_supported": ["mcp:read"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": true
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+ let registration =
+ resolve_client_registration(&client, &discovery, "http://127.0.0.1:12345/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, CIMD_URL);
+ assert_eq!(registration.client_secret, None);
+ assert_eq!(discovery.scopes, vec!["mcp:read"]);
+ });
+ }
+
+ #[test]
+ fn test_full_discover_with_dcr_fallback() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token",
+ "registration_endpoint": "https://auth.example.com/register",
+ "code_challenge_methods_supported": ["S256"],
+ "client_id_metadata_document_supported": false
+ }"#,
+ )
+ } else if uri.contains("/register") {
+ json_response(
+ 201,
+ r#"{
+ "client_id": "dcr-minted-id-123",
+ "client_secret": "dcr-secret-456"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: Some(vec!["files:read".into()]),
+ error: None,
+ error_description: None,
+ };
+
+ let discovery = discover(&client, &server_url, &www_auth).await.unwrap();
+ let registration =
+ resolve_client_registration(&client, &discovery, "http://127.0.0.1:9999/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, "dcr-minted-id-123");
+ assert_eq!(
+ registration.client_secret.as_deref(),
+ Some("dcr-secret-456")
+ );
+ assert_eq!(discovery.scopes, vec!["files:read"]);
+ });
+ }
+
+ #[test]
+ fn test_discover_fails_without_pkce_support() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("oauth-protected-resource") {
+ json_response(
+ 200,
+ r#"{
+ "resource": "https://mcp.example.com",
+ "authorization_servers": ["https://auth.example.com"]
+ }"#,
+ )
+ } else if uri.contains("oauth-authorization-server") {
+ json_response(
+ 200,
+ r#"{
+ "issuer": "https://auth.example.com",
+ "authorization_endpoint": "https://auth.example.com/authorize",
+ "token_endpoint": "https://auth.example.com/token"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let server_url = Url::parse("https://mcp.example.com").unwrap();
+ let www_auth = WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+
+ let result = discover(&client, &server_url, &www_auth).await;
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("code_challenge_methods_supported"),
+ "unexpected error: {}",
+ err_msg
+ );
+ });
+ }
+
+ // -- Token exchange integration tests ------------------------------------
+
+ #[test]
+ fn test_exchange_code_success() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("/token") {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new_access_token",
+ "refresh_token": "new_refresh_token",
+ "expires_in": 3600,
+ "token_type": "Bearer"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+
+ let tokens = exchange_code(
+ &client,
+ &metadata,
+ "auth_code_123",
+ CIMD_URL,
+ "http://127.0.0.1:9999/callback",
+ "verifier_abc",
+ "https://mcp.example.com",
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(tokens.access_token, "new_access_token");
+ assert_eq!(tokens.refresh_token.as_deref(), Some("new_refresh_token"));
+ assert!(tokens.expires_at.is_some());
+ });
+ }
+
+ #[test]
+ fn test_refresh_tokens_success() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|req| {
+ Box::pin(async move {
+ let uri = req.uri().to_string();
+ if uri.contains("/token") {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "refreshed_token",
+ "expires_in": 1800,
+ "token_type": "Bearer"
+ }"#,
+ )
+ } else {
+ json_response(404, "{}")
+ }
+ })
+ });
+
+ let token_endpoint = Url::parse("https://auth.example.com/token").unwrap();
+
+ let tokens = refresh_tokens(
+ &client,
+ &token_endpoint,
+ "old_refresh_token",
+ CIMD_URL,
+ "https://mcp.example.com",
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(tokens.access_token, "refreshed_token");
+ assert_eq!(tokens.refresh_token, None);
+ assert!(tokens.expires_at.is_some());
+ });
+ }
+
+ #[test]
+ fn test_exchange_code_failure() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async move { json_response(400, r#"{"error": "invalid_grant"}"#) })
+ });
+
+ let metadata = AuthServerMetadata {
+ issuer: Url::parse("https://auth.example.com").unwrap(),
+ authorization_endpoint: Url::parse("https://auth.example.com/authorize").unwrap(),
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ registration_endpoint: None,
+ scopes_supported: None,
+ code_challenge_methods_supported: Some(vec!["S256".into()]),
+ client_id_metadata_document_supported: true,
+ };
+
+ let result = exchange_code(
+ &client,
+ &metadata,
+ "bad_code",
+ "client",
+ "http://127.0.0.1:1/callback",
+ "verifier",
+ "https://mcp.example.com",
+ )
+ .await;
+
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("400"));
+ });
+ }
+
+ // -- DCR integration tests -----------------------------------------------
+
+ #[test]
+ fn test_perform_dcr() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async move {
+ json_response(
+ 201,
+ r#"{
+ "client_id": "dynamic-client-001",
+ "client_secret": "dynamic-secret-001"
+ }"#,
+ )
+ })
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let registration = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback")
+ .await
+ .unwrap();
+
+ assert_eq!(registration.client_id, "dynamic-client-001");
+ assert_eq!(
+ registration.client_secret.as_deref(),
+ Some("dynamic-secret-001")
+ );
+ });
+ }
+
+ #[test]
+ fn test_perform_dcr_failure() {
+ smol::block_on(async {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(
+ async move { json_response(403, r#"{"error": "registration_not_allowed"}"#) },
+ )
+ });
+
+ let endpoint = Url::parse("https://auth.example.com/register").unwrap();
+ let result = perform_dcr(&client, &endpoint, "http://127.0.0.1:9999/callback").await;
+
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("403"));
+ });
+ }
+
+ // -- OAuthCallback parse tests -------------------------------------------
+
+ #[test]
+ fn test_oauth_callback_parse_query() {
+ let callback = OAuthCallback::parse_query("code=test_auth_code&state=test_state").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_reversed_order() {
+ let callback = OAuthCallback::parse_query("state=test_state&code=test_auth_code").unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_with_extra_params() {
+ let callback =
+ OAuthCallback::parse_query("code=test_auth_code&state=test_state&extra=ignored")
+ .unwrap();
+ assert_eq!(callback.code, "test_auth_code");
+ assert_eq!(callback.state, "test_state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_code() {
+ let result = OAuthCallback::parse_query("state=test_state");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("code"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_missing_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code");
+ assert!(result.is_err());
+ assert!(result.unwrap_err().to_string().contains("state"));
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_code() {
+ let result = OAuthCallback::parse_query("code=&state=test_state");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_empty_state() {
+ let result = OAuthCallback::parse_query("code=test_auth_code&state=");
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_url_encoded_values() {
+ let callback = OAuthCallback::parse_query("code=abc%20def&state=test%3Dstate").unwrap();
+ assert_eq!(callback.code, "abc def");
+ assert_eq!(callback.state, "test=state");
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_response() {
+ let result = OAuthCallback::parse_query(
+ "error=access_denied&error_description=User%20denied%20access&state=abc",
+ );
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("access_denied"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("User denied access"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ #[test]
+ fn test_oauth_callback_parse_query_error_without_description() {
+ let result = OAuthCallback::parse_query("error=server_error&state=abc");
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("server_error"),
+ "unexpected error: {}",
+ err_msg
+ );
+ assert!(
+ err_msg.contains("no description"),
+ "unexpected error: {}",
+ err_msg
+ );
+ }
+
+ // -- McpOAuthTokenProvider tests -----------------------------------------
+
+ fn make_test_session(
+ access_token: &str,
+ refresh_token: Option<&str>,
+ expires_at: Option<SystemTime>,
+ ) -> OAuthSession {
+ OAuthSession {
+ token_endpoint: Url::parse("https://auth.example.com/token").unwrap(),
+ resource: Url::parse("https://mcp.example.com").unwrap(),
+ client_registration: OAuthClientRegistration {
+ client_id: "test-client".into(),
+ client_secret: None,
+ },
+ tokens: OAuthTokens {
+ access_token: access_token.into(),
+ refresh_token: refresh_token.map(String::from),
+ expires_at,
+ },
+ }
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_none_when_token_expired() {
+ let expired = SystemTime::now() - Duration::from_secs(60);
+ let session = make_test_session("stale-token", Some("rt"), Some(expired));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token(), None);
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_not_expired() {
+ let far_future = SystemTime::now() + Duration::from_secs(3600);
+ let session = make_test_session("valid-token", Some("rt"), Some(far_future));
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("valid-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_returns_token_when_no_expiry() {
+ let session = make_test_session("no-expiry-token", Some("rt"), None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| Box::pin(async { unreachable!() })),
+ None,
+ );
+
+ assert_eq!(provider.access_token().as_deref(), Some("no-expiry-token"));
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_without_refresh_token_returns_false() {
+ smol::block_on(async {
+ let session = make_test_session("token", None, None);
+ let provider = McpOAuthTokenProvider::new(
+ session,
+ make_fake_http_client(|_| {
+ Box::pin(async { unreachable!("no HTTP call expected") })
+ }),
+ None,
+ );
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_updates_session_and_notifies_channel() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh-token"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "refresh_token": "new-refresh",
+ "expires_in": 1800
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+ assert_eq!(provider.access_token().as_deref(), Some("new-access"));
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("new-refresh")
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_preserves_old_refresh_token_when_server_omits_it() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("original-refresh"), None);
+ let (tx, mut rx) = futures::channel::mpsc::unbounded();
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ json_response(
+ 200,
+ r#"{
+ "access_token": "new-access",
+ "expires_in": 900
+ }"#,
+ )
+ })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, Some(tx));
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(refreshed);
+
+ let notified_session = rx
+ .try_next()
+ .unwrap()
+ .expect("channel should have a session");
+ assert_eq!(notified_session.tokens.access_token, "new-access");
+ assert_eq!(
+ notified_session.tokens.refresh_token.as_deref(),
+ Some("original-refresh"),
+ );
+ });
+ }
+
+ #[test]
+ fn test_mcp_oauth_provider_refresh_returns_false_on_http_error() {
+ smol::block_on(async {
+ let session = make_test_session("old-access", Some("my-refresh"), None);
+
+ let http_client = make_fake_http_client(|_req| {
+ Box::pin(async { json_response(401, r#"{"error": "invalid_grant"}"#) })
+ });
+
+ let provider = McpOAuthTokenProvider::new(session, http_client, None);
+
+ let refreshed = provider.try_refresh().await.unwrap();
+ assert!(!refreshed);
+ // The old token should still be in place.
+ assert_eq!(provider.access_token().as_deref(), Some("old-access"));
+ });
+ }
+}
@@ -8,8 +8,30 @@ use parking_lot::Mutex as SyncMutex;
use smol::channel;
use std::{pin::Pin, sync::Arc};
+use crate::oauth::{self, OAuthTokenProvider, WwwAuthenticate};
use crate::transport::Transport;
+/// Typed errors returned by the HTTP transport that callers can downcast from
+/// `anyhow::Error` to handle specific failure modes.
+#[derive(Debug)]
+pub enum TransportError {
+ /// The server returned 401 and token refresh either wasn't possible or
+ /// failed. The caller should initiate the OAuth authorization flow.
+ AuthRequired { www_authenticate: WwwAuthenticate },
+}
+
+impl std::fmt::Display for TransportError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ TransportError::AuthRequired { .. } => {
+ write!(f, "OAuth authorization required")
+ }
+ }
+ }
+}
+
+impl std::error::Error for TransportError {}
+
// Constants from MCP spec
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
@@ -25,8 +47,11 @@ pub struct HttpTransport {
response_rx: channel::Receiver<String>,
error_tx: channel::Sender<String>,
error_rx: channel::Receiver<String>,
- // Authentication headers to include in requests
+ /// Static headers to include in every request (e.g. from server config).
headers: HashMap<String, String>,
+ /// When set, the transport attaches `Authorization: Bearer` headers and
+ /// handles 401 responses with token refresh + retry.
+ token_provider: Option<Arc<dyn OAuthTokenProvider>>,
}
impl HttpTransport {
@@ -35,6 +60,16 @@ impl HttpTransport {
endpoint: String,
headers: HashMap<String, String>,
executor: BackgroundExecutor,
+ ) -> Self {
+ Self::new_with_token_provider(http_client, endpoint, headers, executor, None)
+ }
+
+ pub fn new_with_token_provider(
+ http_client: Arc<dyn HttpClient>,
+ endpoint: String,
+ headers: HashMap<String, String>,
+ executor: BackgroundExecutor,
+ token_provider: Option<Arc<dyn OAuthTokenProvider>>,
) -> Self {
let (response_tx, response_rx) = channel::unbounded();
let (error_tx, error_rx) = channel::unbounded();
@@ -49,14 +84,14 @@ impl HttpTransport {
error_tx,
error_rx,
headers,
+ token_provider,
}
}
- /// Send a message and handle the response based on content type
- async fn send_message(&self, message: String) -> Result<()> {
- let is_notification =
- !message.contains("\"id\":") || message.contains("notifications/initialized");
-
+ /// Build a POST request for the given message body, attaching all standard
+ /// headers (content-type, accept, session ID, static headers, and bearer
+ /// token if available).
+ fn build_request(&self, message: &[u8]) -> Result<http_client::Request<AsyncBody>> {
let mut request_builder = Request::builder()
.method(Method::POST)
.uri(&self.endpoint)
@@ -70,15 +105,71 @@ impl HttpTransport {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
- // Add session ID if we have one (except for initialize)
+ // Attach bearer token when a token provider is present.
+ if let Some(token) = self.token_provider.as_ref().and_then(|p| p.access_token()) {
+ request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
+ // Add session ID if we have one (except for initialize).
if let Some(ref session_id) = *self.session_id.lock() {
request_builder = request_builder.header(HEADER_SESSION_ID, session_id.as_str());
}
- let request = request_builder.body(AsyncBody::from(message.into_bytes()))?;
+ Ok(request_builder.body(AsyncBody::from(message.to_vec()))?)
+ }
+
+ /// Send a message and handle the response based on content type.
+ async fn send_message(&self, message: String) -> Result<()> {
+ let is_notification =
+ !message.contains("\"id\":") || message.contains("notifications/initialized");
+
+ // If we currently have no access token, try refreshing before sending
+ // the request so restored but expired sessions do not need an initial
+ // 401 round-trip before they can recover.
+ if let Some(ref provider) = self.token_provider {
+ if provider.access_token().is_none() {
+ provider.try_refresh().await.unwrap_or(false);
+ }
+ }
+
+ let request = self.build_request(message.as_bytes())?;
let mut response = self.http_client.send(request).await?;
- // Handle different response types based on status and content-type
+ // On 401, try refreshing the token and retry once.
+ if response.status().as_u16() == 401 {
+ let www_auth_header = response
+ .headers()
+ .get("www-authenticate")
+ .and_then(|v| v.to_str().ok())
+ .unwrap_or("Bearer");
+
+ let www_authenticate =
+ oauth::parse_www_authenticate(www_auth_header).unwrap_or(WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ });
+
+ if let Some(ref provider) = self.token_provider {
+ if provider.try_refresh().await.unwrap_or(false) {
+ // Retry with the refreshed token.
+ let retry_request = self.build_request(message.as_bytes())?;
+ response = self.http_client.send(retry_request).await?;
+
+ // If still 401 after refresh, give up.
+ if response.status().as_u16() == 401 {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ } else {
+ return Err(TransportError::AuthRequired { www_authenticate }.into());
+ }
+ }
+
+ // Handle different response types based on status and content-type.
match response.status() {
status if status.is_success() => {
// Check content type
@@ -233,6 +324,7 @@ impl Drop for HttpTransport {
let endpoint = self.endpoint.clone();
let session_id = self.session_id.lock().clone();
let headers = self.headers.clone();
+ let access_token = self.token_provider.as_ref().and_then(|p| p.access_token());
if let Some(session_id) = session_id {
self.executor
@@ -242,11 +334,17 @@ impl Drop for HttpTransport {
.uri(&endpoint)
.header(HEADER_SESSION_ID, &session_id);
- // Add authentication headers if present
+ // Add static authentication headers.
for (key, value) in headers {
request_builder = request_builder.header(key.as_str(), value.as_str());
}
+ // Attach bearer token if available.
+ if let Some(token) = access_token {
+ request_builder =
+ request_builder.header("Authorization", format!("Bearer {}", token));
+ }
+
let request = request_builder.body(AsyncBody::empty());
if let Ok(request) = request {
@@ -257,3 +355,402 @@ impl Drop for HttpTransport {
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use async_trait::async_trait;
+ use gpui::TestAppContext;
+ use parking_lot::Mutex as SyncMutex;
+ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+
+ /// A mock token provider that returns a configurable token and tracks
+ /// refresh attempts.
+ struct FakeTokenProvider {
+ token: SyncMutex<Option<String>>,
+ refreshed_token: SyncMutex<Option<String>>,
+ refresh_succeeds: AtomicBool,
+ refresh_count: AtomicUsize,
+ }
+
+ impl FakeTokenProvider {
+ fn new(token: Option<&str>, refresh_succeeds: bool) -> Arc<Self> {
+ Self::with_refreshed_token(token, None, refresh_succeeds)
+ }
+
+ fn with_refreshed_token(
+ token: Option<&str>,
+ refreshed_token: Option<&str>,
+ refresh_succeeds: bool,
+ ) -> Arc<Self> {
+ Arc::new(Self {
+ token: SyncMutex::new(token.map(String::from)),
+ refreshed_token: SyncMutex::new(refreshed_token.map(String::from)),
+ refresh_succeeds: AtomicBool::new(refresh_succeeds),
+ refresh_count: AtomicUsize::new(0),
+ })
+ }
+
+ fn set_token(&self, token: &str) {
+ *self.token.lock() = Some(token.to_string());
+ }
+
+ fn refresh_count(&self) -> usize {
+ self.refresh_count.load(Ordering::SeqCst)
+ }
+ }
+
+ #[async_trait]
+ impl OAuthTokenProvider for FakeTokenProvider {
+ fn access_token(&self) -> Option<String> {
+ self.token.lock().clone()
+ }
+
+ async fn try_refresh(&self) -> Result<bool> {
+ self.refresh_count.fetch_add(1, Ordering::SeqCst);
+
+ let refresh_succeeds = self.refresh_succeeds.load(Ordering::SeqCst);
+ if refresh_succeeds {
+ if let Some(token) = self.refreshed_token.lock().clone() {
+ *self.token.lock() = Some(token);
+ }
+ }
+
+ Ok(refresh_succeeds)
+ }
+ }
+
+ fn make_fake_http_client(
+ handler: impl Fn(
+ http_client::Request<AsyncBody>,
+ ) -> std::pin::Pin<
+ Box<dyn std::future::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send>,
+ > + Send
+ + Sync
+ + 'static,
+ ) -> Arc<dyn HttpClient> {
+ http_client::FakeHttpClient::create(handler) as Arc<dyn HttpClient>
+ }
+
+ fn json_response(status: u16, body: &str) -> anyhow::Result<Response<AsyncBody>> {
+ Ok(Response::builder()
+ .status(status)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(body.as_bytes().to_vec()))
+ .unwrap())
+ }
+
+ #[gpui::test]
+ async fn test_bearer_token_attached_to_requests(cx: &mut TestAppContext) {
+ // Capture the Authorization header from the request.
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let provider = FakeTokenProvider::new(Some("test-access-token"), false);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed");
+
+ assert_eq!(
+ captured_auth.lock().as_deref(),
+ Some("Bearer test-access-token"),
+ );
+ }
+
+ #[gpui::test]
+ async fn test_no_bearer_token_without_provider(cx: &mut TestAppContext) {
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let transport = HttpTransport::new(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed");
+
+ assert!(captured_auth.lock().is_none());
+ }
+
+ #[gpui::test]
+ async fn test_missing_token_triggers_refresh_before_first_request(cx: &mut TestAppContext) {
+ let captured_auth = Arc::new(SyncMutex::new(None::<String>));
+ let captured_auth_clone = captured_auth.clone();
+
+ let client = make_fake_http_client(move |req| {
+ let auth = req
+ .headers()
+ .get("Authorization")
+ .map(|v| v.to_str().unwrap().to_string());
+ *captured_auth_clone.lock() = auth;
+ Box::pin(async { json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#) })
+ });
+
+ let provider = FakeTokenProvider::with_refreshed_token(None, Some("refreshed-token"), true);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after proactive refresh");
+
+ assert_eq!(provider.refresh_count(), 1);
+ assert_eq!(
+ captured_auth.lock().as_deref(),
+ Some("Bearer refreshed-token"),
+ );
+ }
+
+ #[gpui::test]
+ async fn test_invalid_token_still_triggers_refresh_and_retry(cx: &mut TestAppContext) {
+ let request_count = Arc::new(AtomicUsize::new(0));
+ let request_count_clone = request_count.clone();
+
+ let client = make_fake_http_client(move |_req| {
+ let count = request_count_clone.fetch_add(1, Ordering::SeqCst);
+ Box::pin(async move {
+ if count == 0 {
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer error="invalid_token", resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ } else {
+ json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
+ }
+ })
+ });
+
+ let provider = FakeTokenProvider::with_refreshed_token(
+ Some("old-token"),
+ Some("refreshed-token"),
+ true,
+ );
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after refresh");
+
+ assert_eq!(provider.refresh_count(), 1);
+ assert_eq!(request_count.load(Ordering::SeqCst), 2);
+ }
+
+ #[gpui::test]
+ async fn test_401_triggers_refresh_and_retry(cx: &mut TestAppContext) {
+ let request_count = Arc::new(AtomicUsize::new(0));
+ let request_count_clone = request_count.clone();
+
+ let client = make_fake_http_client(move |_req| {
+ let count = request_count_clone.fetch_add(1, Ordering::SeqCst);
+ Box::pin(async move {
+ if count == 0 {
+ // First request: 401.
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ } else {
+ // Retry after refresh: 200.
+ json_response(200, r#"{"jsonrpc":"2.0","id":1,"result":{}}"#)
+ }
+ })
+ });
+
+ let provider = FakeTokenProvider::new(Some("old-token"), true);
+ // Simulate the refresh updating the token.
+ let provider_ref = provider.clone();
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ // Set the new token that will be used on retry.
+ provider_ref.set_token("refreshed-token");
+
+ transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .expect("send should succeed after refresh");
+
+ assert_eq!(provider_ref.refresh_count(), 1);
+ assert_eq!(request_count.load(Ordering::SeqCst), 2);
+ }
+
+ #[gpui::test]
+ async fn test_401_returns_auth_required_when_refresh_fails(cx: &mut TestAppContext) {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header(
+ "WWW-Authenticate",
+ r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource", scope="read write""#,
+ )
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ // Refresh returns false โ no new token available.
+ let provider = FakeTokenProvider::new(Some("stale-token"), false);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ let transport_err = err
+ .downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ match transport_err {
+ TransportError::AuthRequired { www_authenticate } => {
+ assert_eq!(
+ www_authenticate
+ .resource_metadata
+ .as_ref()
+ .map(|u| u.as_str()),
+ Some("https://mcp.example.com/.well-known/oauth-protected-resource"),
+ );
+ assert_eq!(
+ www_authenticate.scope,
+ Some(vec!["read".to_string(), "write".to_string()]),
+ );
+ }
+ }
+ assert_eq!(provider.refresh_count(), 1);
+ }
+
+ #[gpui::test]
+ async fn test_401_returns_auth_required_without_provider(cx: &mut TestAppContext) {
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Bearer")
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ // No token provider at all.
+ let transport = HttpTransport::new(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ let transport_err = err
+ .downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ match transport_err {
+ TransportError::AuthRequired { www_authenticate } => {
+ assert!(www_authenticate.resource_metadata.is_none());
+ assert!(www_authenticate.scope.is_none());
+ }
+ }
+ }
+
+ #[gpui::test]
+ async fn test_401_after_successful_refresh_still_returns_auth_required(
+ cx: &mut TestAppContext,
+ ) {
+ // Both requests return 401 โ the server rejects the refreshed token too.
+ let client = make_fake_http_client(|_req| {
+ Box::pin(async {
+ Ok(Response::builder()
+ .status(401)
+ .header("WWW-Authenticate", "Bearer")
+ .body(AsyncBody::from(b"Unauthorized".to_vec()))
+ .unwrap())
+ })
+ });
+
+ let provider = FakeTokenProvider::new(Some("token"), true);
+ let transport = HttpTransport::new_with_token_provider(
+ client,
+ "http://mcp.example.com/mcp".to_string(),
+ HashMap::default(),
+ cx.background_executor.clone(),
+ Some(provider.clone()),
+ );
+
+ let err = transport
+ .send(r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#.to_string())
+ .await
+ .unwrap_err();
+
+ err.downcast_ref::<TransportError>()
+ .expect("error should be TransportError");
+ // Refresh was attempted exactly once.
+ assert_eq!(provider.refresh_count(), 1);
+ }
+}
@@ -45,6 +45,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
context_server.workspace = true
+credentials_provider.workspace = true
dap.workspace = true
extension.workspace = true
fancy-regex.workspace = true
@@ -7,10 +7,16 @@ use std::time::Duration;
use anyhow::{Context as _, Result};
use collections::{HashMap, HashSet};
+use context_server::oauth::{self, McpOAuthTokenProvider, OAuthDiscovery, OAuthSession};
+use context_server::transport::{HttpTransport, TransportError};
use context_server::{ContextServer, ContextServerCommand, ContextServerId};
-use futures::{FutureExt as _, future::Either, future::join_all};
+use credentials_provider::CredentialsProvider;
+use futures::future::Either;
+use futures::{FutureExt as _, StreamExt as _, future::join_all};
use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Subscription, Task, WeakEntity, actions};
+use http_client::HttpClient;
use itertools::Itertools;
+use rand::Rng as _;
use registry::ContextServerDescriptorRegistry;
use remote::RemoteClient;
use rpc::{AnyProtoClient, TypedEnvelope, proto};
@@ -45,6 +51,12 @@ pub enum ContextServerStatus {
Running,
Stopped,
Error(Arc<str>),
+ /// The server returned 401 and OAuth authorization is needed. The UI
+ /// should show an "Authenticate" button.
+ AuthRequired,
+ /// The OAuth browser flow is in progress โ the user has been redirected
+ /// to the authorization server and we're waiting for the callback.
+ Authenticating,
}
impl ContextServerStatus {
@@ -54,6 +66,8 @@ impl ContextServerStatus {
ContextServerState::Running { .. } => ContextServerStatus::Running,
ContextServerState::Stopped { .. } => ContextServerStatus::Stopped,
ContextServerState::Error { error, .. } => ContextServerStatus::Error(error.clone()),
+ ContextServerState::AuthRequired { .. } => ContextServerStatus::AuthRequired,
+ ContextServerState::Authenticating { .. } => ContextServerStatus::Authenticating,
}
}
}
@@ -77,24 +91,42 @@ enum ContextServerState {
configuration: Arc<ContextServerConfiguration>,
error: Arc<str>,
},
+ /// The server requires OAuth authorization before it can be used. The
+ /// `OAuthDiscovery` holds everything needed to start the browser flow.
+ AuthRequired {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ discovery: Arc<OAuthDiscovery>,
+ },
+ /// The OAuth browser flow is in progress. The user has been redirected
+ /// to the authorization server and we're waiting for the callback.
+ Authenticating {
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ _task: Task<()>,
+ },
}
impl ContextServerState {
pub fn server(&self) -> Arc<ContextServer> {
match self {
- ContextServerState::Starting { server, .. } => server.clone(),
- ContextServerState::Running { server, .. } => server.clone(),
- ContextServerState::Stopped { server, .. } => server.clone(),
- ContextServerState::Error { server, .. } => server.clone(),
+ ContextServerState::Starting { server, .. }
+ | ContextServerState::Running { server, .. }
+ | ContextServerState::Stopped { server, .. }
+ | ContextServerState::Error { server, .. }
+ | ContextServerState::AuthRequired { server, .. }
+ | ContextServerState::Authenticating { server, .. } => server.clone(),
}
}
pub fn configuration(&self) -> Arc<ContextServerConfiguration> {
match self {
- ContextServerState::Starting { configuration, .. } => configuration.clone(),
- ContextServerState::Running { configuration, .. } => configuration.clone(),
- ContextServerState::Stopped { configuration, .. } => configuration.clone(),
- ContextServerState::Error { configuration, .. } => configuration.clone(),
+ ContextServerState::Starting { configuration, .. }
+ | ContextServerState::Running { configuration, .. }
+ | ContextServerState::Stopped { configuration, .. }
+ | ContextServerState::Error { configuration, .. }
+ | ContextServerState::AuthRequired { configuration, .. }
+ | ContextServerState::Authenticating { configuration, .. } => configuration.clone(),
}
}
}
@@ -126,6 +158,15 @@ impl ContextServerConfiguration {
}
}
+ pub fn has_static_auth_header(&self) -> bool {
+ match self {
+ ContextServerConfiguration::Http { headers, .. } => headers
+ .keys()
+ .any(|k| k.eq_ignore_ascii_case("authorization")),
+ _ => false,
+ }
+ }
+
pub fn remote(&self) -> bool {
match self {
ContextServerConfiguration::Custom { remote, .. } => *remote,
@@ -517,9 +558,10 @@ impl ContextServerStore {
pub fn start_server(&mut self, server: Arc<ContextServer>, cx: &mut Context<Self>) {
cx.spawn(async move |this, cx| {
let this = this.upgrade().context("Context server store dropped")?;
+ let id = server.id();
let settings = this
.update(cx, |this, _| {
- this.context_server_settings.get(&server.id().0).cloned()
+ this.context_server_settings.get(&id.0).cloned()
})
.context("Failed to get context server settings")?;
@@ -532,7 +574,7 @@ impl ContextServerStore {
});
let configuration = ContextServerConfiguration::from_settings(
settings,
- server.id(),
+ id.clone(),
registry,
worktree_store,
cx,
@@ -590,7 +632,11 @@ impl ContextServerStore {
let id = server.id();
if matches!(
self.servers.get(&id),
- Some(ContextServerState::Starting { .. } | ContextServerState::Running { .. })
+ Some(
+ ContextServerState::Starting { .. }
+ | ContextServerState::Running { .. }
+ | ContextServerState::Authenticating { .. },
+ )
) {
self.stop_server(&id, cx).log_err();
}
@@ -600,38 +646,20 @@ impl ContextServerStore {
let configuration = configuration.clone();
async move |this, cx| {
- match server.clone().start(cx).await {
+ let new_state = match server.clone().start(cx).await {
Ok(_) => {
debug_assert!(server.client().is_some());
-
- this.update(cx, |this, cx| {
- this.update_server_state(
- id.clone(),
- ContextServerState::Running {
- server,
- configuration,
- },
- cx,
- )
- })
- .log_err()
- }
- Err(err) => {
- log::error!("{} context server failed to start: {}", id, err);
- this.update(cx, |this, cx| {
- this.update_server_state(
- id.clone(),
- ContextServerState::Error {
- configuration,
- server,
- error: err.to_string().into(),
- },
- cx,
- )
- })
- .log_err()
+ ContextServerState::Running {
+ server,
+ configuration,
+ }
}
+ Err(err) => resolve_start_failure(&id, err, server, configuration, cx).await,
};
+ this.update(cx, |this, cx| {
+ this.update_server_state(id.clone(), new_state, cx)
+ })
+ .log_err();
}
});
@@ -651,6 +679,20 @@ impl ContextServerStore {
.servers
.remove(id)
.context("Context server not found")?;
+
+ if let ContextServerConfiguration::Http { url, .. } = state.configuration().as_ref() {
+ let server_url = url.clone();
+ let id = id.clone();
+ cx.spawn(async move |_this, cx| {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await
+ {
+ log::warn!("{} failed to clear OAuth session on removal: {}", id, err);
+ }
+ })
+ .detach();
+ }
+
drop(state);
cx.emit(ServerStatusChangedEvent {
server_id: id.clone(),
@@ -742,29 +784,71 @@ impl ContextServerStore {
configuration
};
+ if let Some(server) = this.update(cx, |this, _| {
+ this.context_server_factory
+ .as_ref()
+ .map(|factory| factory(id.clone(), configuration.clone()))
+ })? {
+ return Ok((server, configuration));
+ }
+
+ let cached_token_provider: Option<Arc<dyn oauth::OAuthTokenProvider>> =
+ if let ContextServerConfiguration::Http { url, .. } = configuration.as_ref() {
+ if configuration.has_static_auth_header() {
+ None
+ } else {
+ let credentials_provider =
+ cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let http_client = cx.update(|cx| cx.http_client());
+
+ match Self::load_session(&credentials_provider, url, &cx).await {
+ Ok(Some(session)) => {
+ log::info!("{} loaded cached OAuth session from keychain", id);
+ Some(Self::create_oauth_token_provider(
+ &id,
+ url,
+ session,
+ http_client,
+ credentials_provider,
+ cx,
+ ))
+ }
+ Ok(None) => None,
+ Err(err) => {
+ log::warn!("{} failed to load cached OAuth session: {}", id, err);
+ None
+ }
+ }
+ }
+ } else {
+ None
+ };
+
let server: Arc<ContextServer> = this.update(cx, |this, cx| {
let global_timeout =
Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
- if let Some(factory) = this.context_server_factory.as_ref() {
- return anyhow::Ok(factory(id.clone(), configuration.clone()));
- }
-
match configuration.as_ref() {
ContextServerConfiguration::Http {
url,
headers,
timeout,
- } => anyhow::Ok(Arc::new(ContextServer::http(
- id,
- url,
- headers.clone(),
- cx.http_client(),
- cx.background_executor().clone(),
- Some(Duration::from_secs(
- timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
- )),
- )?)),
+ } => {
+ let transport = HttpTransport::new_with_token_provider(
+ cx.http_client(),
+ url.to_string(),
+ headers.clone(),
+ cx.background_executor().clone(),
+ cached_token_provider.clone(),
+ );
+ anyhow::Ok(Arc::new(ContextServer::new_with_timeout(
+ id,
+ Arc::new(transport),
+ Some(Duration::from_secs(
+ timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
+ )),
+ )))
+ }
_ => {
let mut command = configuration
.command()
@@ -861,6 +945,310 @@ impl ContextServerStore {
ProjectSettings::get(location, cx)
}
+ fn create_oauth_token_provider(
+ id: &ContextServerId,
+ server_url: &url::Url,
+ session: OAuthSession,
+ http_client: Arc<dyn HttpClient>,
+ credentials_provider: Arc<dyn CredentialsProvider>,
+ cx: &mut AsyncApp,
+ ) -> Arc<dyn oauth::OAuthTokenProvider> {
+ let (token_refresh_tx, mut token_refresh_rx) = futures::channel::mpsc::unbounded();
+ let id = id.clone();
+ let server_url = server_url.clone();
+
+ cx.spawn(async move |cx| {
+ while let Some(refreshed_session) = token_refresh_rx.next().await {
+ if let Err(err) =
+ Self::store_session(&credentials_provider, &server_url, &refreshed_session, &cx)
+ .await
+ {
+ log::warn!("{} failed to persist refreshed OAuth session: {}", id, err);
+ }
+ }
+ log::debug!("{} OAuth session persistence task ended", id);
+ })
+ .detach();
+
+ Arc::new(McpOAuthTokenProvider::new(
+ session,
+ http_client,
+ Some(token_refresh_tx),
+ ))
+ }
+
+ /// Initiate the OAuth browser flow for a server in the `AuthRequired` state.
+ ///
+ /// This starts a loopback HTTP callback server on an ephemeral port, builds
+ /// the authorization URL, opens the user's browser, waits for the callback,
+ /// exchanges the code for tokens, persists them in the keychain, and restarts
+ /// the server with the new token provider.
+ pub fn authenticate_server(
+ &mut self,
+ id: &ContextServerId,
+ cx: &mut Context<Self>,
+ ) -> Result<()> {
+ let state = self.servers.get(id).context("Context server not found")?;
+
+ let (discovery, server, configuration) = match state {
+ ContextServerState::AuthRequired {
+ discovery,
+ server,
+ configuration,
+ } => (discovery.clone(), server.clone(), configuration.clone()),
+ _ => anyhow::bail!("Server is not in AuthRequired state"),
+ };
+
+ let id = id.clone();
+
+ let task = cx.spawn({
+ let id = id.clone();
+ let server = server.clone();
+ let configuration = configuration.clone();
+ async move |this, cx| {
+ let result = Self::run_oauth_flow(
+ this.clone(),
+ id.clone(),
+ discovery.clone(),
+ configuration.clone(),
+ cx,
+ )
+ .await;
+
+ if let Err(err) = &result {
+ log::error!("{} OAuth authentication failed: {:?}", id, err);
+ // Transition back to AuthRequired so the user can retry
+ // rather than landing in a terminal Error state.
+ this.update(cx, |this, cx| {
+ this.update_server_state(
+ id.clone(),
+ ContextServerState::AuthRequired {
+ server,
+ configuration,
+ discovery,
+ },
+ cx,
+ )
+ })
+ .log_err();
+ }
+ }
+ });
+
+ self.update_server_state(
+ id,
+ ContextServerState::Authenticating {
+ server,
+ configuration,
+ _task: task,
+ },
+ cx,
+ );
+
+ Ok(())
+ }
+
+ async fn run_oauth_flow(
+ this: WeakEntity<Self>,
+ id: ContextServerId,
+ discovery: Arc<OAuthDiscovery>,
+ configuration: Arc<ContextServerConfiguration>,
+ cx: &mut AsyncApp,
+ ) -> Result<()> {
+ let resource = oauth::canonical_server_uri(&discovery.resource_metadata.resource);
+ let pkce = oauth::generate_pkce_challenge();
+
+ let mut state_bytes = [0u8; 32];
+ rand::rng().fill(&mut state_bytes);
+ let state_param: String = state_bytes.iter().map(|b| format!("{:02x}", b)).collect();
+
+ // Start a loopback HTTP server on an ephemeral port. The redirect URI
+ // includes this port so the browser sends the callback directly to our
+ // process.
+ let (redirect_uri, callback_rx) = oauth::start_callback_server()
+ .await
+ .context("Failed to start OAuth callback server")?;
+
+ let http_client = cx.update(|cx| cx.http_client());
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } => url.clone(),
+ _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
+ };
+
+ let client_registration =
+ oauth::resolve_client_registration(&http_client, &discovery, &redirect_uri)
+ .await
+ .context("Failed to resolve OAuth client registration")?;
+
+ let auth_url = oauth::build_authorization_url(
+ &discovery.auth_server_metadata,
+ &client_registration.client_id,
+ &redirect_uri,
+ &discovery.scopes,
+ &resource,
+ &pkce,
+ &state_param,
+ );
+
+ cx.update(|cx| cx.open_url(auth_url.as_str()));
+
+ let callback = callback_rx
+ .await
+ .map_err(|_| {
+ anyhow::anyhow!("OAuth callback server was shut down before receiving a response")
+ })?
+ .context("OAuth callback server received an invalid request")?;
+
+ if callback.state != state_param {
+ anyhow::bail!("OAuth state parameter mismatch (possible CSRF)");
+ }
+
+ let tokens = oauth::exchange_code(
+ &http_client,
+ &discovery.auth_server_metadata,
+ &callback.code,
+ &client_registration.client_id,
+ &redirect_uri,
+ &pkce.verifier,
+ &resource,
+ )
+ .await
+ .context("Failed to exchange authorization code for tokens")?;
+
+ let session = OAuthSession {
+ token_endpoint: discovery.auth_server_metadata.token_endpoint.clone(),
+ resource: discovery.resource_metadata.resource.clone(),
+ client_registration,
+ tokens,
+ };
+
+ Self::store_session(&credentials_provider, &server_url, &session, cx)
+ .await
+ .context("Failed to persist OAuth session in keychain")?;
+
+ let token_provider = Self::create_oauth_token_provider(
+ &id,
+ &server_url,
+ session,
+ http_client.clone(),
+ credentials_provider,
+ cx,
+ );
+
+ let new_server = this.update(cx, |this, cx| {
+ let global_timeout =
+ Self::resolve_project_settings(&this.worktree_store, cx).context_server_timeout;
+
+ match configuration.as_ref() {
+ ContextServerConfiguration::Http {
+ url,
+ headers,
+ timeout,
+ } => {
+ let transport = HttpTransport::new_with_token_provider(
+ http_client.clone(),
+ url.to_string(),
+ headers.clone(),
+ cx.background_executor().clone(),
+ Some(token_provider.clone()),
+ );
+ Ok(Arc::new(ContextServer::new_with_timeout(
+ id.clone(),
+ Arc::new(transport),
+ Some(Duration::from_secs(
+ timeout.unwrap_or(global_timeout).min(MAX_TIMEOUT_SECS),
+ )),
+ )))
+ }
+ _ => anyhow::bail!("OAuth authentication only supported for HTTP servers"),
+ }
+ })??;
+
+ this.update(cx, |this, cx| {
+ this.run_server(new_server, configuration, cx);
+ })?;
+
+ Ok(())
+ }
+
+ /// Store the full OAuth session in the system keychain, keyed by the
+ /// server's canonical URI.
+ async fn store_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ session: &OAuthSession,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let key = Self::keychain_key(server_url);
+ let json = serde_json::to_string(session)?;
+ credentials_provider
+ .write_credentials(&key, "mcp-oauth", json.as_bytes(), cx)
+ .await
+ }
+
+ /// Load the full OAuth session from the system keychain for the given
+ /// server URL.
+ async fn load_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ cx: &AsyncApp,
+ ) -> Result<Option<OAuthSession>> {
+ let key = Self::keychain_key(server_url);
+ match credentials_provider.read_credentials(&key, cx).await? {
+ Some((_username, password_bytes)) => {
+ let session: OAuthSession = serde_json::from_slice(&password_bytes)?;
+ Ok(Some(session))
+ }
+ None => Ok(None),
+ }
+ }
+
+ /// Clear the stored OAuth session from the system keychain.
+ async fn clear_session(
+ credentials_provider: &Arc<dyn CredentialsProvider>,
+ server_url: &url::Url,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let key = Self::keychain_key(server_url);
+ credentials_provider.delete_credentials(&key, cx).await
+ }
+
+ fn keychain_key(server_url: &url::Url) -> String {
+ format!("mcp-oauth:{}", oauth::canonical_server_uri(server_url))
+ }
+
+ /// Log out of an OAuth-authenticated MCP server: clear the stored OAuth
+ /// session from the keychain and stop the server.
+ pub fn logout_server(&mut self, id: &ContextServerId, cx: &mut Context<Self>) -> Result<()> {
+ let state = self.servers.get(id).context("Context server not found")?;
+ let configuration = state.configuration();
+
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } => url.clone(),
+ _ => anyhow::bail!("logout only applies to HTTP servers with OAuth"),
+ };
+
+ let id = id.clone();
+ self.stop_server(&id, cx)?;
+
+ cx.spawn(async move |this, cx| {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ if let Err(err) = Self::clear_session(&credentials_provider, &server_url, &cx).await {
+ log::error!("{} failed to clear OAuth session: {}", id, err);
+ }
+ // Trigger server recreation so the next start uses a fresh
+ // transport without the old (now-invalidated) token provider.
+ this.update(cx, |this, cx| {
+ this.available_context_servers_changed(cx);
+ })
+ .log_err();
+ })
+ .detach();
+
+ Ok(())
+ }
+
fn update_server_state(
&mut self,
id: ContextServerId,
@@ -1014,3 +1402,104 @@ impl ContextServerStore {
Ok(())
}
}
+
+/// Determines the appropriate server state after a start attempt fails.
+///
+/// When the error is an HTTP 401 with no static auth header configured,
+/// attempts OAuth discovery so the UI can offer an authentication flow.
+async fn resolve_start_failure(
+ id: &ContextServerId,
+ err: anyhow::Error,
+ server: Arc<ContextServer>,
+ configuration: Arc<ContextServerConfiguration>,
+ cx: &AsyncApp,
+) -> ContextServerState {
+ let www_authenticate = err.downcast_ref::<TransportError>().map(|e| match e {
+ TransportError::AuthRequired { www_authenticate } => www_authenticate.clone(),
+ });
+
+ if www_authenticate.is_some() && configuration.has_static_auth_header() {
+ log::warn!("{id} received 401 with a static Authorization header configured");
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: "Server returned 401 Unauthorized. Check your configured Authorization header."
+ .into(),
+ };
+ }
+
+ let server_url = match configuration.as_ref() {
+ ContextServerConfiguration::Http { url, .. } if !configuration.has_static_auth_header() => {
+ url.clone()
+ }
+ _ => {
+ if www_authenticate.is_some() {
+ log::error!("{id} got OAuth 401 on a non-HTTP transport or with static auth");
+ } else {
+ log::error!("{id} context server failed to start: {err}");
+ }
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: err.to_string().into(),
+ };
+ }
+ };
+
+ // When the error is NOT a 401 but there is a cached OAuth session in the
+ // keychain, the session is likely stale/expired and caused the failure
+ // (e.g. timeout because the server rejected the token silently). Clear it
+ // so the next start attempt can get a clean 401 and trigger the auth flow.
+ if www_authenticate.is_none() {
+ let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx));
+ match ContextServerStore::load_session(&credentials_provider, &server_url, cx).await {
+ Ok(Some(_)) => {
+ log::info!("{id} start failed with a cached OAuth session present; clearing it");
+ ContextServerStore::clear_session(&credentials_provider, &server_url, cx)
+ .await
+ .log_err();
+ }
+ _ => {
+ log::error!("{id} context server failed to start: {err}");
+ return ContextServerState::Error {
+ configuration,
+ server,
+ error: err.to_string().into(),
+ };
+ }
+ }
+ }
+
+ let default_www_authenticate = oauth::WwwAuthenticate {
+ resource_metadata: None,
+ scope: None,
+ error: None,
+ error_description: None,
+ };
+ let www_authenticate = www_authenticate
+ .as_ref()
+ .unwrap_or(&default_www_authenticate);
+ let http_client = cx.update(|cx| cx.http_client());
+
+ match context_server::oauth::discover(&http_client, &server_url, www_authenticate).await {
+ Ok(discovery) => {
+ log::info!(
+ "{id} requires OAuth authorization (auth server: {})",
+ discovery.auth_server_metadata.issuer,
+ );
+ ContextServerState::AuthRequired {
+ server,
+ configuration,
+ discovery: Arc::new(discovery),
+ }
+ }
+ Err(discovery_err) => {
+ log::error!("{id} OAuth discovery failed: {discovery_err}");
+ ContextServerState::Error {
+ configuration,
+ server,
+ error: format!("OAuth discovery failed: {discovery_err}").into(),
+ }
+ }
+ }
+}
@@ -162,7 +162,7 @@ impl RenderOnce for ModalHeader {
children.insert(
0,
Headline::new(headline)
- .size(HeadlineSize::XSmall)
+ .size(HeadlineSize::Small)
.color(Color::Muted)
.into_any_element(),
);