diff --git a/Cargo.lock b/Cargo.lock index 63abadce0ce0bbada4c5b8ffbe564dab27cd172a..a782de048787ce7bde37ba16bbb869a4026b4a6d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3454,7 +3454,6 @@ dependencies = [ "smol", "tempfile", "terminal", - "tiny_http", "url", "util", ] @@ -8217,6 +8216,7 @@ dependencies = [ "serde_urlencoded", "sha2", "tempfile", + "tiny_http", "url", "util", ] @@ -9388,6 +9388,7 @@ dependencies = [ "open_ai", "open_router", "opencode", + "parking_lot", "partial-json-fixer", "pretty_assertions", "rand 0.9.2", diff --git a/crates/context_server/Cargo.toml b/crates/context_server/Cargo.toml index 0a9c94a54d70196c0a0fee04dec249ea367d56c0..15a4c3ab557b16dc9d1f9b9c363e007b92bdbf5d 100644 --- a/crates/context_server/Cargo.toml +++ b/crates/context_server/Cargo.toml @@ -35,7 +35,6 @@ sha2.workspace = true slotmap.workspace = true smol.workspace = true tempfile.workspace = true -tiny_http.workspace = true url = { workspace = true, features = ["serde"] } util.workspace = true terminal.workspace = true diff --git a/crates/context_server/src/oauth.rs b/crates/context_server/src/oauth.rs index 44bcbbd557a67f40cbcf7621c52ddbd8d818b9d4..a21741d262ec7f0d8b02dc0aebf8dbc598ac3300 100644 --- a/crates/context_server/src/oauth.rs +++ b/crates/context_server/src/oauth.rs @@ -27,11 +27,9 @@ use rand::Rng as _; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, SystemTime}; use url::Url; -use util::ResultExt as _; /// The CIMD URL where Zed's OAuth client metadata document is hosted. pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json"; @@ -992,58 +990,14 @@ impl OAuthCallback { /// Parse the query string from a callback URL like /// `http://127.0.0.1:/callback?code=...&state=...`. pub fn parse_query(query: &str) -> Result { - let mut code: Option = None; - let mut state: Option = None; - let mut error: Option = None; - let mut error_description: Option = None; - - for (key, value) in url::form_urlencoded::parse(query.as_bytes()) { - match key.as_ref() { - "code" => { - if !value.is_empty() { - code = Some(value.into_owned()); - } - } - "state" => { - if !value.is_empty() { - state = Some(value.into_owned()); - } - } - "error" => { - if !value.is_empty() { - error = Some(value.into_owned()); - } - } - "error_description" => { - if !value.is_empty() { - error_description = Some(value.into_owned()); - } - } - _ => {} - } - } - - // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before - // checking for missing code/state. - if let Some(error_code) = error { - bail!( - "OAuth authorization failed: {} ({})", - error_code, - error_description.as_deref().unwrap_or("no description") - ); - } - - let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?; - let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?; - - Ok(Self { code, state }) + let params = http_client::OAuthCallbackParams::parse_query(query)?; + Ok(Self { + code: params.code, + state: params.state, + }) } } -/// How long to wait for the browser to complete the OAuth flow before giving -/// up and releasing the loopback port. -const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60); - /// Start a loopback HTTP server to receive the OAuth authorization callback. /// /// Binds to an ephemeral loopback port for each flow. @@ -1057,107 +1011,26 @@ const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60); /// HTML page telling the user they can close the tab, and shuts down. /// /// The callback server shuts down when the returned oneshot receiver is dropped -/// (e.g. because the authentication task was cancelled), or after a timeout -/// ([CALLBACK_TIMEOUT]). +/// (e.g. because the authentication task was cancelled), or after a timeout. pub async fn start_callback_server() -> Result<( String, futures::channel::oneshot::Receiver>, )> { - let server = tiny_http::Server::http("127.0.0.1:0") - .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?; - let port = server - .server_addr() - .to_ip() - .context("server not bound to a TCP address")? - .port(); - - let redirect_uri = format!("http://127.0.0.1:{}/callback", port); - - let (tx, rx) = futures::channel::oneshot::channel(); - - // `tiny_http` is blocking, so we run it on a background thread. - // The `recv_timeout` loop lets us check for cancellation (the receiver - // being dropped) and enforce an overall timeout. - std::thread::spawn(move || { - let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT; - - loop { - if tx.is_canceled() { - return; - } - let remaining = deadline.saturating_duration_since(std::time::Instant::now()); - if remaining.is_zero() { - return; - } - - let timeout = remaining.min(Duration::from_millis(500)); - let Some(request) = (match server.recv_timeout(timeout) { - Ok(req) => req, - Err(_) => { - let _ = tx.send(Err(anyhow!("OAuth callback server I/O error"))); - return; - } - }) else { - // Timeout with no request — loop back and check cancellation. - continue; - }; - - let result = handle_callback_request(&request); - - let (status_code, body) = match &result { - Ok(_) => ( - 200, - http_client::oauth_callback_page( - "Authorization Successful", - "You can close this tab and return to Zed.", - ), - ), - Err(err) => { - log::error!("OAuth callback error: {}", err); - ( - 400, - http_client::oauth_callback_page( - "Authorization Failed", - "Something went wrong. Please try again from Zed.", - ), - ) - } - }; - - let response = tiny_http::Response::from_string(body) - .with_status_code(status_code) - .with_header( - tiny_http::Header::from_str("Content-Type: text/html") - .expect("failed to construct response header"), - ) - .with_header( - tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0") - .expect("failed to construct response header"), - ); - request.respond(response).log_err(); - - let _ = tx.send(result); - return; - } - }); - - Ok((redirect_uri, rx)) -} - -/// Extract the `code` and `state` query parameters from an OAuth callback -/// request to `/callback`. -fn handle_callback_request(request: &tiny_http::Request) -> Result { - let url = Url::parse(&format!("http://localhost{}", request.url())) - .context("malformed callback request URL")?; - - if url.path() != "/callback" { - bail!("unexpected path in OAuth callback: {}", url.path()); - } - - let query = url - .query() - .ok_or_else(|| anyhow!("OAuth callback has no query string"))?; - OAuthCallback::parse_query(query) + let (redirect_uri, rx) = http_client::start_oauth_callback_server()?; + let (tx, mapped_rx) = futures::channel::oneshot::channel(); + smol::spawn(async move { + let result = match rx.await { + Ok(Ok(params)) => Ok(OAuthCallback { + code: params.code, + state: params.state, + }), + Ok(Err(e)) => Err(e), + Err(_) => Err(anyhow!("OAuth callback channel was cancelled")), + }; + let _ = tx.send(result); + }) + .detach(); + Ok((redirect_uri, mapped_rx)) } // -- JSON fetch helper ------------------------------------------------------- diff --git a/crates/http_client/Cargo.toml b/crates/http_client/Cargo.toml index 6273d773d8c4651fd292555e18d2a2462e6358df..09bdadde0e5bc168829b9d1dec55171b5e2e5c41 100644 --- a/crates/http_client/Cargo.toml +++ b/crates/http_client/Cargo.toml @@ -37,3 +37,4 @@ async-fs.workspace = true async-tar.workspace = true sha2.workspace = true tempfile.workspace = true +tiny_http.workspace = true diff --git a/crates/http_client/src/http_client.rs b/crates/http_client/src/http_client.rs index b1f097fdac62107f2ba7aa9d571a4688f775fee4..7681465575172c313cce83b6861183d09ea93587 100644 --- a/crates/http_client/src/http_client.rs +++ b/crates/http_client/src/http_client.rs @@ -291,12 +291,43 @@ impl HttpClient for HttpClientWithUrl { } } +fn html_escape(input: &str) -> String { + let mut output = String::with_capacity(input.len()); + for ch in input.chars() { + match ch { + '&' => output.push_str("&"), + '<' => output.push_str("<"), + '>' => output.push_str(">"), + '"' => output.push_str("""), + '\'' => output.push_str("'"), + _ => output.push(ch), + } + } + output +} + /// Generate a styled HTML page for OAuth callback responses. /// /// Returns a complete HTML document (no HTTP headers) with a centered card /// layout styled to match Zed's dark theme. The `title` is rendered as a /// heading and `message` as body text below it. -pub fn oauth_callback_page(title: &str, message: &str) -> String { +/// +/// When `is_error` is true, a red X icon is shown instead of the green +/// checkmark. +pub fn oauth_callback_page(title: &str, message: &str, is_error: bool) -> String { + let title = html_escape(title); + let message = html_escape(message); + let (icon_bg, icon_svg) = if is_error { + ( + "#f38ba8", + r#""#, + ) + } else { + ( + "#a6e3a1", + r#""#, + ) + }; format!( r#" @@ -329,7 +360,7 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String { width: 48px; height: 48px; margin: 0 auto 1.5rem; - background: #a6e3a1; + background: {icon_bg}; border-radius: 50%; display: flex; align-items: center; @@ -364,7 +395,7 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String {
- + {icon_svg}

{title}

{message}

@@ -374,6 +405,8 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String { "#, title = title, message = message, + icon_bg = icon_bg, + icon_svg = icon_svg, ) } @@ -529,3 +562,184 @@ impl HttpClient for FakeHttpClient { self } } + +// --------------------------------------------------------------------------- +// Shared OAuth callback server (non-wasm only) +// --------------------------------------------------------------------------- + +#[cfg(not(target_family = "wasm"))] +mod oauth_callback_server { + use super::*; + use anyhow::Context as _; + use std::str::FromStr; + use std::time::Duration; + + /// Parsed OAuth callback parameters from the authorization server redirect. + pub struct OAuthCallbackParams { + pub code: String, + pub state: String, + } + + impl OAuthCallbackParams { + /// Parse the query string from a callback URL like + /// `http://127.0.0.1:/callback?code=...&state=...`. + pub fn parse_query(query: &str) -> Result { + let mut code: Option = None; + let mut state: Option = None; + let mut error: Option = None; + let mut error_description: Option = None; + + for (key, value) in url::form_urlencoded::parse(query.as_bytes()) { + match key.as_ref() { + "code" => { + if !value.is_empty() { + code = Some(value.into_owned()); + } + } + "state" => { + if !value.is_empty() { + state = Some(value.into_owned()); + } + } + "error" => { + if !value.is_empty() { + error = Some(value.into_owned()); + } + } + "error_description" => { + if !value.is_empty() { + error_description = Some(value.into_owned()); + } + } + _ => {} + } + } + + if let Some(error_code) = error { + anyhow::bail!( + "OAuth authorization failed: {} ({})", + error_code, + error_description.as_deref().unwrap_or("no description") + ); + } + + let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?; + let state = + state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?; + + Ok(Self { code, state }) + } + } + + /// How long to wait for the browser to complete the OAuth flow before giving + /// up and releasing the loopback port. + const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60); + + /// Start a loopback HTTP server to receive the OAuth authorization callback. + /// + /// Binds to an ephemeral loopback port. Returns `(redirect_uri, callback_future)`. + /// The caller should use the redirect URI in the authorization request, open + /// the browser, then await the future to receive the callback. + pub fn start_oauth_callback_server() -> Result<( + String, + futures::channel::oneshot::Receiver>, + )> { + let server = tiny_http::Server::http("127.0.0.1:0").map_err(|e| { + anyhow!(e).context("Failed to bind loopback listener for OAuth callback") + })?; + let port = server + .server_addr() + .to_ip() + .ok_or_else(|| anyhow!("server not bound to a TCP address"))? + .port(); + + let redirect_uri = format!("http://127.0.0.1:{}/callback", port); + + let (tx, rx) = futures::channel::oneshot::channel(); + + std::thread::spawn(move || { + let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT; + + loop { + if tx.is_canceled() { + return; + } + let remaining = deadline.saturating_duration_since(std::time::Instant::now()); + if remaining.is_zero() { + return; + } + + let timeout = remaining.min(Duration::from_millis(500)); + let Some(request) = (match server.recv_timeout(timeout) { + Ok(req) => req, + Err(_) => { + let _ = tx.send(Err(anyhow!("OAuth callback server I/O error"))); + return; + } + }) else { + continue; + }; + + let result = handle_oauth_callback_request(&request); + + let (status_code, body) = match &result { + Ok(_) => ( + 200, + oauth_callback_page( + "Authorization Successful", + "You can close this tab and return to Zed.", + false, + ), + ), + Err(err) => { + log::error!("OAuth callback error: {}", err); + ( + 400, + oauth_callback_page( + "Authorization Failed", + "Something went wrong. Please try again from Zed.", + true, + ), + ) + } + }; + + let response = tiny_http::Response::from_string(body) + .with_status_code(status_code) + .with_header( + tiny_http::Header::from_str("Content-Type: text/html") + .expect("failed to construct response header"), + ) + .with_header( + tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0") + .expect("failed to construct response header"), + ); + if let Err(err) = request.respond(response) { + log::error!("Failed to send OAuth callback response: {}", err); + } + + let _ = tx.send(result); + return; + } + }); + + Ok((redirect_uri, rx)) + } + + fn handle_oauth_callback_request(request: &tiny_http::Request) -> Result { + let url = Url::parse(&format!("http://localhost{}", request.url())) + .context("malformed callback request URL")?; + + if url.path() != "/callback" { + anyhow::bail!("unexpected path in OAuth callback: {}", url.path()); + } + + let query = url + .query() + .ok_or_else(|| anyhow!("OAuth callback has no query string"))?; + OAuthCallbackParams::parse_query(query) + } +} + +#[cfg(not(target_family = "wasm"))] +pub use oauth_callback_server::{OAuthCallbackParams, start_oauth_callback_server}; diff --git a/crates/language_models/src/provider/openai_subscribed.rs b/crates/language_models/src/provider/openai_subscribed.rs index 3f7a767723af410871b8e7fae3d61246bc6c1e73..bceb446bb600cc9309b02e8b79e9c43111cd3713 100644 --- a/crates/language_models/src/provider/openai_subscribed.rs +++ b/crates/language_models/src/provider/openai_subscribed.rs @@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow}; use base64::Engine as _; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, StreamExt, future::BoxFuture, future::Either, future::Shared}; +use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use language_model::{ @@ -15,9 +15,8 @@ use open_ai::{ReasoningEffort, responses::stream_response}; use rand::RngCore as _; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; -use smol::io::{AsyncReadExt as _, AsyncWriteExt as _}; use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::time::{SystemTime, UNIX_EPOCH}; use ui::{ConfiguredApiCard, prelude::*}; use url::form_urlencoded; use util::ResultExt as _; @@ -32,7 +31,7 @@ const CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex"; const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; const OPENAI_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize"; const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann"; -const REDIRECT_URI: &str = "http://localhost:1455/auth/callback"; + const CREDENTIALS_KEY: &str = "https://chatgpt.com/backend-api/codex"; const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000; @@ -56,7 +55,25 @@ pub struct State { credentials: Option, sign_in_task: Option>>, refresh_task: Option>>>>, + load_task: Option>>>>, credentials_provider: Arc, + auth_generation: u64, + last_auth_error: Option, +} + +#[derive(Debug)] +enum RefreshError { + Fatal(anyhow::Error), + Transient(anyhow::Error), +} + +impl std::fmt::Display for RefreshError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RefreshError::Fatal(e) => write!(f, "{e}"), + RefreshError::Transient(e) => write!(f, "{e}"), + } + } } impl State { @@ -88,7 +105,10 @@ impl OpenAiSubscribedProvider { credentials: None, sign_in_task: None, refresh_task: None, + load_task: None, credentials_provider, + auth_generation: 0, + last_auth_error: None, }); let provider = Self { http_client, state }; @@ -100,31 +120,38 @@ impl OpenAiSubscribedProvider { fn load_credentials(&self, cx: &mut App) { let state = self.state.downgrade(); - cx.spawn(async move |cx| { - let credentials_provider = - state.read_with(&*cx, |s, _| s.credentials_provider.clone())?; - let result = credentials_provider - .read_credentials(CREDENTIALS_KEY, &*cx) - .await; - state.update(cx, |s, cx| { - if let Ok(Some((_, bytes))) = result { - match serde_json::from_slice::(&bytes) { - Ok(creds) => s.credentials = Some(creds), - Err(err) => { - log::warn!( - "Failed to deserialize ChatGPT subscription credentials: {err}" - ); + let load_task = cx + .spawn(async move |cx| { + let credentials_provider = + state.read_with(&*cx, |s, _| s.credentials_provider.clone())?; + let result = credentials_provider + .read_credentials(CREDENTIALS_KEY, &*cx) + .await; + state.update(cx, |s, cx| { + if let Ok(Some((_, bytes))) = result { + match serde_json::from_slice::(&bytes) { + Ok(creds) => s.credentials = Some(creds), + Err(err) => { + log::warn!( + "Failed to deserialize ChatGPT subscription credentials: {err}" + ); + } } } - } - cx.notify(); + s.load_task = None; + cx.notify(); + })?; + Ok::<(), Arc>(()) }) - }) - .detach(); + .shared(); + + self.state.update(cx, |s, _| { + s.load_task = Some(load_task); + }); } - fn sign_out(&self, cx: &mut App) { - do_sign_out(&self.state.downgrade(), cx); + fn sign_out(&self, cx: &mut App) -> Task> { + do_sign_out(&self.state.downgrade(), cx) } fn create_language_model(&self, model: ChatGptModel) -> Arc { @@ -182,10 +209,29 @@ impl LanguageModelProvider for OpenAiSubscribedProvider { if self.is_authenticated(cx) { return Task::ready(Ok(())); } - Task::ready(Err(anyhow!( - "Sign in with your ChatGPT Plus or Pro subscription to use this provider." - ) - .into())) + let load_task = self.state.read(cx).load_task.clone(); + if let Some(load_task) = load_task { + let weak_state = self.state.downgrade(); + cx.spawn(async move |cx| { + let _ = load_task.await; + let is_auth = weak_state + .read_with(&*cx, |s, _| s.is_authenticated()) + .unwrap_or(false); + if is_auth { + Ok(()) + } else { + Err(anyhow!( + "Sign in with your ChatGPT Plus or Pro subscription to use this provider." + ) + .into()) + } + }) + } else { + Task::ready(Err(anyhow!( + "Sign in with your ChatGPT Plus or Pro subscription to use this provider." + ) + .into())) + } } fn configuration_view( @@ -201,8 +247,7 @@ impl LanguageModelProvider for OpenAiSubscribedProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.sign_out(cx); - Task::ready(Ok(())) + self.sign_out(cx) } } @@ -380,10 +425,20 @@ impl LanguageModel for OpenAiSubscribedLanguageModel { fn count_tokens( &self, - _request: LanguageModelRequest, - _cx: &App, + request: LanguageModelRequest, + cx: &App, ) -> BoxFuture<'static, Result> { - futures::future::ready(Ok(0)).boxed() + let max_token_count = self.model.max_token_count(); + cx.background_spawn(async move { + let messages = crate::provider::open_ai::collect_tiktoken_messages(request); + let model = if max_token_count >= 100_000 { + "gpt-4o" + } else { + "gpt-4" + }; + tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64) + }) + .boxed() } fn stream_completion( @@ -501,6 +556,11 @@ async fn get_fresh_credentials( let state_clone = state.clone(); let refresh_token_value = creds.refresh_token.clone(); + // Capture the generation so we can detect sign-outs that happened during refresh. + let generation = state + .read_with(&*cx, |s, _| s.auth_generation) + .map_err(LanguageModelCompletionError::Other)?; + let shared_task = cx .spawn(async move |cx| { let result = refresh_token(&http_client_clone, &refresh_token_value).await; @@ -508,6 +568,16 @@ async fn get_fresh_credentials( match result { Ok(refreshed) => { let persist_result: Result> = async { + // Check if auth_generation changed (sign-out during refresh). + let current_generation = state_clone + .read_with(&*cx, |s, _| s.auth_generation) + .map_err(|e| Arc::new(e))?; + if current_generation != generation { + return Err(Arc::new(anyhow!( + "Sign-out occurred during token refresh" + ))); + } + let credentials_provider = state_clone .read_with(&*cx, |s, _| s.credentials_provider.clone()) .map_err(|e| Arc::new(e))?; @@ -540,7 +610,28 @@ async fn get_fresh_credentials( persist_result } - Err(e) => { + Err(RefreshError::Fatal(e)) => { + log::error!("ChatGPT subscription token refresh failed fatally: {e:?}"); + let _ = state_clone.update(cx, |s, cx| { + s.refresh_task = None; + s.credentials = None; + s.last_auth_error = + Some("Your session has expired. Please sign in again.".into()); + cx.notify(); + }); + // Also clear the keychain so stale credentials aren't loaded next time. + if let Ok(credentials_provider) = + state_clone.read_with(&*cx, |s, _| s.credentials_provider.clone()) + { + credentials_provider + .delete_credentials(CREDENTIALS_KEY, &*cx) + .await + .log_err(); + } + Err(Arc::new(e)) + } + Err(RefreshError::Transient(e)) => { + log::warn!("ChatGPT subscription token refresh failed transiently: {e:?}"); let _ = state_clone.update(cx, |s, _| { s.refresh_task = None; }); @@ -577,6 +668,10 @@ async fn do_oauth_flow( http_client: Arc, cx: &AsyncApp, ) -> Result { + // Start the callback server FIRST so the redirect URI is ready + let (redirect_uri, callback_rx) = http_client::start_oauth_callback_server() + .context("Failed to start OAuth callback server")?; + // PKCE verifier: 32 random bytes → base64url (no padding) let mut verifier_bytes = [0u8; 32]; rand::rng().fill_bytes(&mut verifier_bytes); @@ -596,7 +691,7 @@ async fn do_oauth_flow( auth_url .query_pairs_mut() .append_pair("client_id", CLIENT_ID) - .append_pair("redirect_uri", REDIRECT_URI) + .append_pair("redirect_uri", &redirect_uri) .append_pair("scope", "openid profile email offline_access") .append_pair("response_type", "code") .append_pair("code_challenge", &challenge) @@ -605,13 +700,21 @@ async fn do_oauth_flow( .append_pair("codex_cli_simplified_flow", "true") .append_pair("originator", "zed"); + // Open browser AFTER the listener is ready cx.update(|cx| cx.open_url(auth_url.as_str())); - let code = await_oauth_callback(&oauth_state, cx) + // Await the callback + let callback = callback_rx .await + .map_err(|_| anyhow!("OAuth callback was cancelled"))? .context("OAuth callback failed")?; - let tokens = exchange_code(&http_client, &code, &verifier) + // Validate CSRF state + if callback.state != oauth_state { + return Err(anyhow!("OAuth state mismatch")); + } + + let tokens = exchange_code(&http_client, &callback.code, &verifier, &redirect_uri) .await .context("Token exchange failed")?; @@ -630,93 +733,17 @@ async fn do_oauth_flow( }) } -async fn await_oauth_callback(expected_state: &str, cx: &AsyncApp) -> Result { - let listener = smol::net::TcpListener::bind("127.0.0.1:1455") - .await - .context("Failed to bind to port 1455 for OAuth callback. Another application may be using this port.")?; - - let accept_future = listener.accept(); - let timeout_future = cx.background_executor().timer(Duration::from_secs(120)); - - let (mut stream, _) = match futures::future::select( - std::pin::pin!(accept_future), - std::pin::pin!(timeout_future), - ) - .await - { - Either::Left((result, _)) => result?, - Either::Right((_, _)) => { - return Err(anyhow!( - "OAuth sign-in timed out after 2 minutes. Please try again." - )); - } - }; - - let mut buffer = vec![0u8; 4096]; - let mut total = 0; - loop { - if total >= buffer.len() { - return Err(anyhow!("OAuth callback request too large")); - } - let n = stream.read(&mut buffer[total..]).await?; - if n == 0 { - break; - } - total += n; - if buffer[..total].windows(4).any(|w| w == b"\r\n\r\n") { - break; - } - } - let request_text = std::str::from_utf8(&buffer[..total])?; - - // First line: "GET /auth/callback?code=...&state=... HTTP/1.1" - let path = request_text - .lines() - .next() - .and_then(|line| line.split_whitespace().nth(1)) - .ok_or_else(|| anyhow!("Invalid HTTP request from browser"))?; - - let query = path.split('?').nth(1).unwrap_or(""); - let mut code: Option = None; - let mut received_state: Option = None; - for (key, value) in form_urlencoded::parse(query.as_bytes()) { - match key.as_ref() { - "code" => code = Some(value.into_owned()), - "state" => received_state = Some(value.into_owned()), - _ => {} - } - } - - let page = http_client::oauth_callback_page( - "Signed In", - "You've signed into Zed via your ChatGPT subscription. You can close this tab and return to Zed.", - ); - let response = format!( - "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\n\r\n{}", - page.len(), - page - ); - stream.write_all(response.as_bytes()).await.log_err(); - - let received_state = - received_state.ok_or_else(|| anyhow!("Missing state in OAuth callback"))?; - if received_state != expected_state { - return Err(anyhow!("OAuth state mismatch")); - } - - code.ok_or_else(|| anyhow!("Missing authorization code in OAuth callback")) -} - async fn exchange_code( client: &Arc, code: &str, verifier: &str, + redirect_uri: &str, ) -> Result { let body = form_urlencoded::Serializer::new(String::new()) .append_pair("grant_type", "authorization_code") .append_pair("client_id", CLIENT_ID) .append_pair("code", code) - .append_pair("redirect_uri", REDIRECT_URI) + .append_pair("redirect_uri", redirect_uri) .append_pair("code_verifier", verifier) .finish(); @@ -743,7 +770,7 @@ async fn exchange_code( async fn refresh_token( client: &Arc, refresh_token: &str, -) -> Result { +) -> Result { let body = form_urlencoded::Serializer::new(String::new()) .append_pair("grant_type", "refresh_token") .append_pair("client_id", CLIENT_ID) @@ -754,20 +781,34 @@ async fn refresh_token( .method(Method::POST) .uri(OPENAI_TOKEN_URL) .header("Content-Type", "application/x-www-form-urlencoded") - .body(AsyncBody::from(body))?; + .body(AsyncBody::from(body)) + .map_err(|e| RefreshError::Transient(e.into()))?; - let mut response = client.send(request).await?; + let mut response = client + .send(request) + .await + .map_err(|e| RefreshError::Transient(e))?; + let status = response.status(); let mut body = String::new(); - smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body).await?; - - if !response.status().is_success() { - return Err(anyhow!( - "Token refresh failed (HTTP {}): {body}", - response.status() - )); + smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body) + .await + .map_err(|e| RefreshError::Transient(e.into()))?; + + if !status.is_success() { + let err = anyhow!("Token refresh failed (HTTP {}): {body}", status); + // 400/401/403 indicate a revoked or invalid refresh token. + // 5xx and other errors are treated as transient. + if status == http_client::StatusCode::BAD_REQUEST + || status == http_client::StatusCode::UNAUTHORIZED + || status == http_client::StatusCode::FORBIDDEN + { + return Err(RefreshError::Fatal(err)); + } + return Err(RefreshError::Transient(err)); } - let tokens: TokenResponse = serde_json::from_str(&body)?; + let tokens: TokenResponse = + serde_json::from_str(&body).map_err(|e| RefreshError::Transient(e.into()))?; let jwt = tokens .id_token .as_deref() @@ -876,6 +917,7 @@ fn do_sign_in(state: &Entity, http_client: &Arc, cx: &mut .update(cx, |s, cx| { s.credentials = Some(creds); s.sign_in_task = None; + s.last_auth_error = None; cx.notify(); }) .log_err(); @@ -887,6 +929,8 @@ fn do_sign_in(state: &Entity, http_client: &Arc, cx: &mut weak_state .update(cx, |s, cx| { s.sign_in_task = None; + s.last_auth_error = + Some("Failed to save credentials. Please try again.".into()); cx.notify(); }) .log_err(); @@ -898,6 +942,7 @@ fn do_sign_in(state: &Entity, http_client: &Arc, cx: &mut weak_state .update(cx, |s, cx| { s.sign_in_task = None; + s.last_auth_error = Some("Sign-in failed. Please try again.".into()); cx.notify(); }) .log_err(); @@ -907,28 +952,36 @@ fn do_sign_in(state: &Entity, http_client: &Arc, cx: &mut }); state.update(cx, |s, cx| { + s.last_auth_error = None; s.sign_in_task = Some(task); cx.notify(); }); } -fn do_sign_out(state: &gpui::WeakEntity, cx: &mut App) { +fn do_sign_out(state: &gpui::WeakEntity, cx: &mut App) -> Task> { let weak_state = state.clone(); + // Clear credentials and cancel in-flight work immediately so the UI + // reflects the sign-out right away. + weak_state + .update(cx, |s, cx| { + s.auth_generation += 1; + s.credentials = None; + s.sign_in_task = None; + s.refresh_task = None; + s.last_auth_error = None; + cx.notify(); + }) + .log_err(); + cx.spawn(async move |cx| { let credentials_provider = weak_state.read_with(&*cx, |s, _| s.credentials_provider.clone())?; credentials_provider .delete_credentials(CREDENTIALS_KEY, &*cx) .await - .log_err(); - weak_state.update(cx, |s, cx| { - s.credentials = None; - s.sign_in_task = None; - cx.notify(); - })?; + .context("Failed to delete ChatGPT subscription credentials from keychain")?; anyhow::Ok(()) }) - .detach(); } struct ConfigurationView { @@ -952,7 +1005,7 @@ impl Render for ConfigurationView { ConfiguredApiCard::new(SharedString::from(label)) .button_label("Sign Out") .on_click(cx.listener(move |_this, _, _window, cx| { - do_sign_out(&weak_state, cx); + do_sign_out(&weak_state, cx).detach_and_log_err(cx); })), ) .into_any_element(); @@ -964,11 +1017,15 @@ impl Render for ConfigurationView { .into_any_element(); } + let last_auth_error = state.last_auth_error.clone(); let provider_state = self.state.clone(); let http_client = self.http_client.clone(); v_flex() .gap_2() + .when_some(last_auth_error, |this, error| { + this.child(Label::new(error).color(Color::Error)) + }) .child(Label::new( "Sign in with your ChatGPT Plus or Pro subscription to use OpenAI models in Zed's agent.", )) @@ -1085,7 +1142,10 @@ mod tests { credentials: Some(make_expired_credentials()), sign_in_task: None, refresh_task: None, + load_task: None, credentials_provider: Arc::new(FakeCredentialsProvider::new()), + auth_generation: 0, + last_auth_error: None, }); let weak_state = cx.read(|_cx| state.downgrade()); @@ -1138,7 +1198,10 @@ mod tests { credentials: Some(make_fresh_credentials()), sign_in_task: None, refresh_task: None, + load_task: None, credentials_provider: Arc::new(FakeCredentialsProvider::new()), + auth_generation: 0, + last_auth_error: None, }); let weak_state = cx.read(|_cx| state.downgrade()); @@ -1171,7 +1234,10 @@ mod tests { credentials: None, sign_in_task: None, refresh_task: None, + load_task: None, credentials_provider: Arc::new(FakeCredentialsProvider::new()), + auth_generation: 0, + last_auth_error: None, }); let weak_state = cx.read(|_cx| state.downgrade()); @@ -1188,4 +1254,225 @@ mod tests { Err(LanguageModelCompletionError::NoApiKey { .. }) )); } + + #[gpui::test] + async fn test_fatal_refresh_clears_auth_state(cx: &mut TestAppContext) { + let http_client = FakeHttpClient::create(move |_request| async move { + Ok(http_client::Response::builder() + .status(401) + .body(http_client::AsyncBody::from(r#"{"error":"invalid_grant"}"#))?) + }); + + let creds_provider = Arc::new(FakeCredentialsProvider::new()); + let state = cx.new(|_cx| State { + credentials: Some(make_expired_credentials()), + sign_in_task: None, + refresh_task: None, + load_task: None, + credentials_provider: creds_provider.clone(), + auth_generation: 0, + last_auth_error: None, + }); + + let weak_state = cx.read(|_cx| state.downgrade()); + let http: Arc = http_client; + + let weak = weak_state.clone(); + let http_clone = http.clone(); + let result = cx + .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await) + .await; + + cx.run_until_parked(); + + assert!(result.is_err(), "fatal refresh should return an error"); + cx.read(|cx| { + let s = state.read(cx); + assert!( + s.credentials.is_none(), + "credentials should be cleared on fatal refresh failure" + ); + assert!( + s.last_auth_error.is_some(), + "last_auth_error should be set on fatal refresh failure" + ); + }); + } + + #[gpui::test] + async fn test_transient_refresh_keeps_credentials(cx: &mut TestAppContext) { + let http_client = FakeHttpClient::create(move |_request| async move { + Ok(http_client::Response::builder() + .status(500) + .body(http_client::AsyncBody::from("Internal Server Error"))?) + }); + + let state = cx.new(|_cx| State { + credentials: Some(make_expired_credentials()), + sign_in_task: None, + refresh_task: None, + load_task: None, + credentials_provider: Arc::new(FakeCredentialsProvider::new()), + auth_generation: 0, + last_auth_error: None, + }); + + let weak_state = cx.read(|_cx| state.downgrade()); + let http: Arc = http_client; + + let weak = weak_state.clone(); + let http_clone = http.clone(); + let result = cx + .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await) + .await; + + cx.run_until_parked(); + + assert!(result.is_err(), "transient refresh should return an error"); + cx.read(|cx| { + let s = state.read(cx); + assert!( + s.credentials.is_some(), + "credentials should be kept on transient refresh failure" + ); + assert!( + s.last_auth_error.is_none(), + "last_auth_error should not be set on transient refresh failure" + ); + }); + } + + #[gpui::test] + async fn test_sign_out_during_refresh_discards_result(cx: &mut TestAppContext) { + let (gate_tx, gate_rx) = futures::channel::oneshot::channel::<()>(); + let gate_rx = Arc::new(Mutex::new(Some(gate_rx))); + let gate_rx_clone = gate_rx.clone(); + + let http_client = FakeHttpClient::create(move |_request| { + let gate_rx = gate_rx_clone.clone(); + async move { + // Wait until the gate is opened, simulating a slow network. + let rx = gate_rx.lock().take(); + if let Some(rx) = rx { + let _ = rx.await; + } + let body = fake_token_response(); + Ok(http_client::Response::builder() + .status(200) + .body(http_client::AsyncBody::from(body))?) + } + }); + + let creds_provider = Arc::new(FakeCredentialsProvider::new()); + let state = cx.new(|_cx| State { + credentials: Some(make_expired_credentials()), + sign_in_task: None, + refresh_task: None, + load_task: None, + credentials_provider: creds_provider.clone(), + auth_generation: 0, + last_auth_error: None, + }); + + let weak_state = cx.read(|_cx| state.downgrade()); + let http: Arc = http_client; + + // Start a refresh + let weak = weak_state.clone(); + let http_clone = http.clone(); + let refresh_task = + cx.spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await); + + cx.run_until_parked(); + + // Sign out while the refresh is in-flight + cx.update(|cx| { + do_sign_out(&weak_state, cx).detach(); + }); + cx.run_until_parked(); + + // Now let the refresh respond by opening the gate + let _ = gate_tx.send(()); + cx.run_until_parked(); + + let result = refresh_task.await; + assert!(result.is_err(), "refresh should fail after sign-out"); + + cx.read(|cx| { + let s = state.read(cx); + assert!( + s.credentials.is_none(), + "sign-out should have cleared credentials" + ); + }); + } + + #[gpui::test] + async fn test_sign_out_completes_fully(cx: &mut TestAppContext) { + let creds_provider = Arc::new(FakeCredentialsProvider::new()); + // Pre-populate the credential store + creds_provider + .storage + .lock() + .replace(("Bearer".to_string(), b"some-creds".to_vec())); + + let state = cx.new(|_cx| State { + credentials: Some(make_fresh_credentials()), + sign_in_task: None, + refresh_task: None, + load_task: None, + credentials_provider: creds_provider.clone(), + auth_generation: 0, + last_auth_error: None, + }); + + let weak_state = cx.read(|_cx| state.downgrade()); + let sign_out_task = cx.update(|cx| do_sign_out(&weak_state, cx)); + + cx.run_until_parked(); + sign_out_task.await.expect("sign-out should succeed"); + + assert!( + creds_provider.storage.lock().is_none(), + "credential store should be empty after sign-out" + ); + cx.read(|cx| { + assert!( + !state.read(cx).is_authenticated(), + "state should show not authenticated" + ); + }); + } + + #[gpui::test] + async fn test_authenticate_awaits_initial_load(cx: &mut TestAppContext) { + let creds = make_fresh_credentials(); + let creds_json = serde_json::to_vec(&creds).unwrap(); + let creds_provider = Arc::new(FakeCredentialsProvider::new()); + creds_provider + .storage + .lock() + .replace(("Bearer".to_string(), creds_json)); + + let http_client = FakeHttpClient::create(|_| async { + Ok(http_client::Response::builder() + .status(200) + .body(http_client::AsyncBody::default())?) + }); + + let provider = + cx.update(|cx| OpenAiSubscribedProvider::new(http_client, creds_provider, cx)); + + // Before load completes, authenticate should still await the load. + let auth_task = cx.update(|cx| provider.authenticate(cx)); + + // Drive the load to completion. + cx.run_until_parked(); + + let result = auth_task.await; + assert!( + result.is_ok(), + "authenticate should succeed after load completes with valid credentials" + ); + } }