@@ -291,12 +291,43 @@ impl HttpClient for HttpClientWithUrl {
}
}
+fn html_escape(input: &str) -> String {
+ let mut output = String::with_capacity(input.len());
+ for ch in input.chars() {
+ match ch {
+ '&' => output.push_str("&"),
+ '<' => output.push_str("<"),
+ '>' => output.push_str(">"),
+ '"' => output.push_str("""),
+ '\'' => output.push_str("'"),
+ _ => output.push(ch),
+ }
+ }
+ output
+}
+
/// Generate a styled HTML page for OAuth callback responses.
///
/// Returns a complete HTML document (no HTTP headers) with a centered card
/// layout styled to match Zed's dark theme. The `title` is rendered as a
/// heading and `message` as body text below it.
-pub fn oauth_callback_page(title: &str, message: &str) -> String {
+///
+/// When `is_error` is true, a red X icon is shown instead of the green
+/// checkmark.
+pub fn oauth_callback_page(title: &str, message: &str, is_error: bool) -> String {
+ let title = html_escape(title);
+ let message = html_escape(message);
+ let (icon_bg, icon_svg) = if is_error {
+ (
+ "#f38ba8",
+ r#"<svg viewBox="0 0 24 24"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>"#,
+ )
+ } else {
+ (
+ "#a6e3a1",
+ r#"<svg viewBox="0 0 24 24"><polyline points="20 6 9 17 4 12"/></svg>"#,
+ )
+ };
format!(
r#"<!DOCTYPE html>
<html lang="en">
@@ -329,7 +360,7 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String {
width: 48px;
height: 48px;
margin: 0 auto 1.5rem;
- background: #a6e3a1;
+ background: {icon_bg};
border-radius: 50%;
display: flex;
align-items: center;
@@ -364,7 +395,7 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String {
<body>
<div class="card">
<div class="icon">
- <svg viewBox="0 0 24 24"><polyline points="20 6 9 17 4 12"/></svg>
+ {icon_svg}
</div>
<h1>{title}</h1>
<p>{message}</p>
@@ -374,6 +405,8 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String {
</html>"#,
title = title,
message = message,
+ icon_bg = icon_bg,
+ icon_svg = icon_svg,
)
}
@@ -529,3 +562,184 @@ impl HttpClient for FakeHttpClient {
self
}
}
+
+// ---------------------------------------------------------------------------
+// Shared OAuth callback server (non-wasm only)
+// ---------------------------------------------------------------------------
+
+#[cfg(not(target_family = "wasm"))]
+mod oauth_callback_server {
+ use super::*;
+ use anyhow::Context as _;
+ use std::str::FromStr;
+ use std::time::Duration;
+
+ /// Parsed OAuth callback parameters from the authorization server redirect.
+ pub struct OAuthCallbackParams {
+ pub code: String,
+ pub state: String,
+ }
+
+ impl OAuthCallbackParams {
+ /// Parse the query string from a callback URL like
+ /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
+ pub fn parse_query(query: &str) -> Result<Self> {
+ let mut code: Option<String> = None;
+ let mut state: Option<String> = None;
+ let mut error: Option<String> = None;
+ let mut error_description: Option<String> = None;
+
+ for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
+ match key.as_ref() {
+ "code" => {
+ if !value.is_empty() {
+ code = Some(value.into_owned());
+ }
+ }
+ "state" => {
+ if !value.is_empty() {
+ state = Some(value.into_owned());
+ }
+ }
+ "error" => {
+ if !value.is_empty() {
+ error = Some(value.into_owned());
+ }
+ }
+ "error_description" => {
+ if !value.is_empty() {
+ error_description = Some(value.into_owned());
+ }
+ }
+ _ => {}
+ }
+ }
+
+ if let Some(error_code) = error {
+ anyhow::bail!(
+ "OAuth authorization failed: {} ({})",
+ error_code,
+ error_description.as_deref().unwrap_or("no description")
+ );
+ }
+
+ let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
+ let state =
+ state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
+
+ Ok(Self { code, state })
+ }
+ }
+
+ /// How long to wait for the browser to complete the OAuth flow before giving
+ /// up and releasing the loopback port.
+ const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
+
+ /// Start a loopback HTTP server to receive the OAuth authorization callback.
+ ///
+ /// Binds to an ephemeral loopback port. Returns `(redirect_uri, callback_future)`.
+ /// The caller should use the redirect URI in the authorization request, open
+ /// the browser, then await the future to receive the callback.
+ pub fn start_oauth_callback_server() -> Result<(
+ String,
+ futures::channel::oneshot::Receiver<Result<OAuthCallbackParams>>,
+ )> {
+ let server = tiny_http::Server::http("127.0.0.1:0").map_err(|e| {
+ anyhow!(e).context("Failed to bind loopback listener for OAuth callback")
+ })?;
+ let port = server
+ .server_addr()
+ .to_ip()
+ .ok_or_else(|| anyhow!("server not bound to a TCP address"))?
+ .port();
+
+ let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
+
+ let (tx, rx) = futures::channel::oneshot::channel();
+
+ std::thread::spawn(move || {
+ let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT;
+
+ loop {
+ if tx.is_canceled() {
+ return;
+ }
+ let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+ if remaining.is_zero() {
+ return;
+ }
+
+ let timeout = remaining.min(Duration::from_millis(500));
+ let Some(request) = (match server.recv_timeout(timeout) {
+ Ok(req) => req,
+ Err(_) => {
+ let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
+ return;
+ }
+ }) else {
+ continue;
+ };
+
+ let result = handle_oauth_callback_request(&request);
+
+ let (status_code, body) = match &result {
+ Ok(_) => (
+ 200,
+ oauth_callback_page(
+ "Authorization Successful",
+ "You can close this tab and return to Zed.",
+ false,
+ ),
+ ),
+ Err(err) => {
+ log::error!("OAuth callback error: {}", err);
+ (
+ 400,
+ oauth_callback_page(
+ "Authorization Failed",
+ "Something went wrong. Please try again from Zed.",
+ true,
+ ),
+ )
+ }
+ };
+
+ let response = tiny_http::Response::from_string(body)
+ .with_status_code(status_code)
+ .with_header(
+ tiny_http::Header::from_str("Content-Type: text/html")
+ .expect("failed to construct response header"),
+ )
+ .with_header(
+ tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
+ .expect("failed to construct response header"),
+ );
+ if let Err(err) = request.respond(response) {
+ log::error!("Failed to send OAuth callback response: {}", err);
+ }
+
+ let _ = tx.send(result);
+ return;
+ }
+ });
+
+ Ok((redirect_uri, rx))
+ }
+
+ fn handle_oauth_callback_request(request: &tiny_http::Request) -> Result<OAuthCallbackParams> {
+ let url = Url::parse(&format!("http://localhost{}", request.url()))
+ .context("malformed callback request URL")?;
+
+ if url.path() != "/callback" {
+ anyhow::bail!("unexpected path in OAuth callback: {}", url.path());
+ }
+
+ let query = url
+ .query()
+ .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
+ OAuthCallbackParams::parse_query(query)
+ }
+}
+
+#[cfg(not(target_family = "wasm"))]
+pub use oauth_callback_server::{OAuthCallbackParams, start_oauth_callback_server};
@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, StreamExt, future::BoxFuture, future::Either, future::Shared};
+use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use language_model::{
@@ -15,9 +15,8 @@ use open_ai::{ReasoningEffort, responses::stream_response};
use rand::RngCore as _;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
-use smol::io::{AsyncReadExt as _, AsyncWriteExt as _};
use std::sync::Arc;
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use std::time::{SystemTime, UNIX_EPOCH};
use ui::{ConfiguredApiCard, prelude::*};
use url::form_urlencoded;
use util::ResultExt as _;
@@ -32,7 +31,7 @@ const CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex";
const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
const OPENAI_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
-const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
+
const CREDENTIALS_KEY: &str = "https://chatgpt.com/backend-api/codex";
const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000;
@@ -56,7 +55,25 @@ pub struct State {
credentials: Option<CodexCredentials>,
sign_in_task: Option<Task<Result<()>>>,
refresh_task: Option<Shared<Task<Result<CodexCredentials, Arc<anyhow::Error>>>>>,
+ load_task: Option<Shared<Task<Result<(), Arc<anyhow::Error>>>>>,
credentials_provider: Arc<dyn CredentialsProvider>,
+ auth_generation: u64,
+ last_auth_error: Option<SharedString>,
+}
+
+#[derive(Debug)]
+enum RefreshError {
+ Fatal(anyhow::Error),
+ Transient(anyhow::Error),
+}
+
+impl std::fmt::Display for RefreshError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ RefreshError::Fatal(e) => write!(f, "{e}"),
+ RefreshError::Transient(e) => write!(f, "{e}"),
+ }
+ }
}
impl State {
@@ -88,7 +105,10 @@ impl OpenAiSubscribedProvider {
credentials: None,
sign_in_task: None,
refresh_task: None,
+ load_task: None,
credentials_provider,
+ auth_generation: 0,
+ last_auth_error: None,
});
let provider = Self { http_client, state };
@@ -100,31 +120,38 @@ impl OpenAiSubscribedProvider {
fn load_credentials(&self, cx: &mut App) {
let state = self.state.downgrade();
- cx.spawn(async move |cx| {
- let credentials_provider =
- state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
- let result = credentials_provider
- .read_credentials(CREDENTIALS_KEY, &*cx)
- .await;
- state.update(cx, |s, cx| {
- if let Ok(Some((_, bytes))) = result {
- match serde_json::from_slice::<CodexCredentials>(&bytes) {
- Ok(creds) => s.credentials = Some(creds),
- Err(err) => {
- log::warn!(
- "Failed to deserialize ChatGPT subscription credentials: {err}"
- );
+ let load_task = cx
+ .spawn(async move |cx| {
+ let credentials_provider =
+ state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
+ let result = credentials_provider
+ .read_credentials(CREDENTIALS_KEY, &*cx)
+ .await;
+ state.update(cx, |s, cx| {
+ if let Ok(Some((_, bytes))) = result {
+ match serde_json::from_slice::<CodexCredentials>(&bytes) {
+ Ok(creds) => s.credentials = Some(creds),
+ Err(err) => {
+ log::warn!(
+ "Failed to deserialize ChatGPT subscription credentials: {err}"
+ );
+ }
}
}
- }
- cx.notify();
+ s.load_task = None;
+ cx.notify();
+ })?;
+ Ok::<(), Arc<anyhow::Error>>(())
})
- })
- .detach();
+ .shared();
+
+ self.state.update(cx, |s, _| {
+ s.load_task = Some(load_task);
+ });
}
- fn sign_out(&self, cx: &mut App) {
- do_sign_out(&self.state.downgrade(), cx);
+ fn sign_out(&self, cx: &mut App) -> Task<Result<()>> {
+ do_sign_out(&self.state.downgrade(), cx)
}
fn create_language_model(&self, model: ChatGptModel) -> Arc<dyn LanguageModel> {
@@ -182,10 +209,29 @@ impl LanguageModelProvider for OpenAiSubscribedProvider {
if self.is_authenticated(cx) {
return Task::ready(Ok(()));
}
- Task::ready(Err(anyhow!(
- "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
- )
- .into()))
+ let load_task = self.state.read(cx).load_task.clone();
+ if let Some(load_task) = load_task {
+ let weak_state = self.state.downgrade();
+ cx.spawn(async move |cx| {
+ let _ = load_task.await;
+ let is_auth = weak_state
+ .read_with(&*cx, |s, _| s.is_authenticated())
+ .unwrap_or(false);
+ if is_auth {
+ Ok(())
+ } else {
+ Err(anyhow!(
+ "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
+ )
+ .into())
+ }
+ })
+ } else {
+ Task::ready(Err(anyhow!(
+ "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
+ )
+ .into()))
+ }
}
fn configuration_view(
@@ -201,8 +247,7 @@ impl LanguageModelProvider for OpenAiSubscribedProvider {
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
- self.sign_out(cx);
- Task::ready(Ok(()))
+ self.sign_out(cx)
}
}
@@ -380,10 +425,20 @@ impl LanguageModel for OpenAiSubscribedLanguageModel {
fn count_tokens(
&self,
- _request: LanguageModelRequest,
- _cx: &App,
+ request: LanguageModelRequest,
+ cx: &App,
) -> BoxFuture<'static, Result<u64>> {
- futures::future::ready(Ok(0)).boxed()
+ let max_token_count = self.model.max_token_count();
+ cx.background_spawn(async move {
+ let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
+ let model = if max_token_count >= 100_000 {
+ "gpt-4o"
+ } else {
+ "gpt-4"
+ };
+ tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
+ })
+ .boxed()
}
fn stream_completion(
@@ -501,6 +556,11 @@ async fn get_fresh_credentials(
let state_clone = state.clone();
let refresh_token_value = creds.refresh_token.clone();
+ // Capture the generation so we can detect sign-outs that happened during refresh.
+ let generation = state
+ .read_with(&*cx, |s, _| s.auth_generation)
+ .map_err(LanguageModelCompletionError::Other)?;
+
let shared_task = cx
.spawn(async move |cx| {
let result = refresh_token(&http_client_clone, &refresh_token_value).await;
@@ -508,6 +568,16 @@ async fn get_fresh_credentials(
match result {
Ok(refreshed) => {
let persist_result: Result<CodexCredentials, Arc<anyhow::Error>> = async {
+ // Check if auth_generation changed (sign-out during refresh).
+ let current_generation = state_clone
+ .read_with(&*cx, |s, _| s.auth_generation)
+ .map_err(|e| Arc::new(e))?;
+ if current_generation != generation {
+ return Err(Arc::new(anyhow!(
+ "Sign-out occurred during token refresh"
+ )));
+ }
+
let credentials_provider = state_clone
.read_with(&*cx, |s, _| s.credentials_provider.clone())
.map_err(|e| Arc::new(e))?;
@@ -540,7 +610,28 @@ async fn get_fresh_credentials(
persist_result
}
- Err(e) => {
+ Err(RefreshError::Fatal(e)) => {
+ log::error!("ChatGPT subscription token refresh failed fatally: {e:?}");
+ let _ = state_clone.update(cx, |s, cx| {
+ s.refresh_task = None;
+ s.credentials = None;
+ s.last_auth_error =
+ Some("Your session has expired. Please sign in again.".into());
+ cx.notify();
+ });
+ // Also clear the keychain so stale credentials aren't loaded next time.
+ if let Ok(credentials_provider) =
+ state_clone.read_with(&*cx, |s, _| s.credentials_provider.clone())
+ {
+ credentials_provider
+ .delete_credentials(CREDENTIALS_KEY, &*cx)
+ .await
+ .log_err();
+ }
+ Err(Arc::new(e))
+ }
+ Err(RefreshError::Transient(e)) => {
+ log::warn!("ChatGPT subscription token refresh failed transiently: {e:?}");
let _ = state_clone.update(cx, |s, _| {
s.refresh_task = None;
});
@@ -577,6 +668,10 @@ async fn do_oauth_flow(
http_client: Arc<dyn HttpClient>,
cx: &AsyncApp,
) -> Result<CodexCredentials> {
+ // Start the callback server FIRST so the redirect URI is ready
+ let (redirect_uri, callback_rx) = http_client::start_oauth_callback_server()
+ .context("Failed to start OAuth callback server")?;
+
// PKCE verifier: 32 random bytes → base64url (no padding)
let mut verifier_bytes = [0u8; 32];
rand::rng().fill_bytes(&mut verifier_bytes);
@@ -596,7 +691,7 @@ async fn do_oauth_flow(
auth_url
.query_pairs_mut()
.append_pair("client_id", CLIENT_ID)
- .append_pair("redirect_uri", REDIRECT_URI)
+ .append_pair("redirect_uri", &redirect_uri)
.append_pair("scope", "openid profile email offline_access")
.append_pair("response_type", "code")
.append_pair("code_challenge", &challenge)
@@ -605,13 +700,21 @@ async fn do_oauth_flow(
.append_pair("codex_cli_simplified_flow", "true")
.append_pair("originator", "zed");
+ // Open browser AFTER the listener is ready
cx.update(|cx| cx.open_url(auth_url.as_str()));
- let code = await_oauth_callback(&oauth_state, cx)
+ // Await the callback
+ let callback = callback_rx
.await
+ .map_err(|_| anyhow!("OAuth callback was cancelled"))?
.context("OAuth callback failed")?;
- let tokens = exchange_code(&http_client, &code, &verifier)
+ // Validate CSRF state
+ if callback.state != oauth_state {
+ return Err(anyhow!("OAuth state mismatch"));
+ }
+
+ let tokens = exchange_code(&http_client, &callback.code, &verifier, &redirect_uri)
.await
.context("Token exchange failed")?;
@@ -630,93 +733,17 @@ async fn do_oauth_flow(
})
}
-async fn await_oauth_callback(expected_state: &str, cx: &AsyncApp) -> Result<String> {
- let listener = smol::net::TcpListener::bind("127.0.0.1:1455")
- .await
- .context("Failed to bind to port 1455 for OAuth callback. Another application may be using this port.")?;
-
- let accept_future = listener.accept();
- let timeout_future = cx.background_executor().timer(Duration::from_secs(120));
-
- let (mut stream, _) = match futures::future::select(
- std::pin::pin!(accept_future),
- std::pin::pin!(timeout_future),
- )
- .await
- {
- Either::Left((result, _)) => result?,
- Either::Right((_, _)) => {
- return Err(anyhow!(
- "OAuth sign-in timed out after 2 minutes. Please try again."
- ));
- }
- };
-
- let mut buffer = vec![0u8; 4096];
- let mut total = 0;
- loop {
- if total >= buffer.len() {
- return Err(anyhow!("OAuth callback request too large"));
- }
- let n = stream.read(&mut buffer[total..]).await?;
- if n == 0 {
- break;
- }
- total += n;
- if buffer[..total].windows(4).any(|w| w == b"\r\n\r\n") {
- break;
- }
- }
- let request_text = std::str::from_utf8(&buffer[..total])?;
-
- // First line: "GET /auth/callback?code=...&state=... HTTP/1.1"
- let path = request_text
- .lines()
- .next()
- .and_then(|line| line.split_whitespace().nth(1))
- .ok_or_else(|| anyhow!("Invalid HTTP request from browser"))?;
-
- let query = path.split('?').nth(1).unwrap_or("");
- let mut code: Option<String> = None;
- let mut received_state: Option<String> = None;
- for (key, value) in form_urlencoded::parse(query.as_bytes()) {
- match key.as_ref() {
- "code" => code = Some(value.into_owned()),
- "state" => received_state = Some(value.into_owned()),
- _ => {}
- }
- }
-
- let page = http_client::oauth_callback_page(
- "Signed In",
- "You've signed into Zed via your ChatGPT subscription. You can close this tab and return to Zed.",
- );
- let response = format!(
- "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\n\r\n{}",
- page.len(),
- page
- );
- stream.write_all(response.as_bytes()).await.log_err();
-
- let received_state =
- received_state.ok_or_else(|| anyhow!("Missing state in OAuth callback"))?;
- if received_state != expected_state {
- return Err(anyhow!("OAuth state mismatch"));
- }
-
- code.ok_or_else(|| anyhow!("Missing authorization code in OAuth callback"))
-}
-
async fn exchange_code(
client: &Arc<dyn HttpClient>,
code: &str,
verifier: &str,
+ redirect_uri: &str,
) -> Result<TokenResponse> {
let body = form_urlencoded::Serializer::new(String::new())
.append_pair("grant_type", "authorization_code")
.append_pair("client_id", CLIENT_ID)
.append_pair("code", code)
- .append_pair("redirect_uri", REDIRECT_URI)
+ .append_pair("redirect_uri", redirect_uri)
.append_pair("code_verifier", verifier)
.finish();
@@ -743,7 +770,7 @@ async fn exchange_code(
async fn refresh_token(
client: &Arc<dyn HttpClient>,
refresh_token: &str,
-) -> Result<CodexCredentials> {
+) -> Result<CodexCredentials, RefreshError> {
let body = form_urlencoded::Serializer::new(String::new())
.append_pair("grant_type", "refresh_token")
.append_pair("client_id", CLIENT_ID)
@@ -754,20 +781,34 @@ async fn refresh_token(
.method(Method::POST)
.uri(OPENAI_TOKEN_URL)
.header("Content-Type", "application/x-www-form-urlencoded")
- .body(AsyncBody::from(body))?;
+ .body(AsyncBody::from(body))
+ .map_err(|e| RefreshError::Transient(e.into()))?;
- let mut response = client.send(request).await?;
+ let mut response = client
+ .send(request)
+ .await
+ .map_err(|e| RefreshError::Transient(e))?;
+ let status = response.status();
let mut body = String::new();
- smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body).await?;
-
- if !response.status().is_success() {
- return Err(anyhow!(
- "Token refresh failed (HTTP {}): {body}",
- response.status()
- ));
+ smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
+ .await
+ .map_err(|e| RefreshError::Transient(e.into()))?;
+
+ if !status.is_success() {
+ let err = anyhow!("Token refresh failed (HTTP {}): {body}", status);
+ // 400/401/403 indicate a revoked or invalid refresh token.
+ // 5xx and other errors are treated as transient.
+ if status == http_client::StatusCode::BAD_REQUEST
+ || status == http_client::StatusCode::UNAUTHORIZED
+ || status == http_client::StatusCode::FORBIDDEN
+ {
+ return Err(RefreshError::Fatal(err));
+ }
+ return Err(RefreshError::Transient(err));
}
- let tokens: TokenResponse = serde_json::from_str(&body)?;
+ let tokens: TokenResponse =
+ serde_json::from_str(&body).map_err(|e| RefreshError::Transient(e.into()))?;
let jwt = tokens
.id_token
.as_deref()
@@ -876,6 +917,7 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
.update(cx, |s, cx| {
s.credentials = Some(creds);
s.sign_in_task = None;
+ s.last_auth_error = None;
cx.notify();
})
.log_err();
@@ -887,6 +929,8 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
weak_state
.update(cx, |s, cx| {
s.sign_in_task = None;
+ s.last_auth_error =
+ Some("Failed to save credentials. Please try again.".into());
cx.notify();
})
.log_err();
@@ -898,6 +942,7 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
weak_state
.update(cx, |s, cx| {
s.sign_in_task = None;
+ s.last_auth_error = Some("Sign-in failed. Please try again.".into());
cx.notify();
})
.log_err();
@@ -907,28 +952,36 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
});
state.update(cx, |s, cx| {
+ s.last_auth_error = None;
s.sign_in_task = Some(task);
cx.notify();
});
}
-fn do_sign_out(state: &gpui::WeakEntity<State>, cx: &mut App) {
+fn do_sign_out(state: &gpui::WeakEntity<State>, cx: &mut App) -> Task<Result<()>> {
let weak_state = state.clone();
+ // Clear credentials and cancel in-flight work immediately so the UI
+ // reflects the sign-out right away.
+ weak_state
+ .update(cx, |s, cx| {
+ s.auth_generation += 1;
+ s.credentials = None;
+ s.sign_in_task = None;
+ s.refresh_task = None;
+ s.last_auth_error = None;
+ cx.notify();
+ })
+ .log_err();
+
cx.spawn(async move |cx| {
let credentials_provider =
weak_state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
credentials_provider
.delete_credentials(CREDENTIALS_KEY, &*cx)
.await
- .log_err();
- weak_state.update(cx, |s, cx| {
- s.credentials = None;
- s.sign_in_task = None;
- cx.notify();
- })?;
+ .context("Failed to delete ChatGPT subscription credentials from keychain")?;
anyhow::Ok(())
})
- .detach();
}
struct ConfigurationView {
@@ -952,7 +1005,7 @@ impl Render for ConfigurationView {
ConfiguredApiCard::new(SharedString::from(label))
.button_label("Sign Out")
.on_click(cx.listener(move |_this, _, _window, cx| {
- do_sign_out(&weak_state, cx);
+ do_sign_out(&weak_state, cx).detach_and_log_err(cx);
})),
)
.into_any_element();
@@ -964,11 +1017,15 @@ impl Render for ConfigurationView {
.into_any_element();
}
+ let last_auth_error = state.last_auth_error.clone();
let provider_state = self.state.clone();
let http_client = self.http_client.clone();
v_flex()
.gap_2()
+ .when_some(last_auth_error, |this, error| {
+ this.child(Label::new(error).color(Color::Error))
+ })
.child(Label::new(
"Sign in with your ChatGPT Plus or Pro subscription to use OpenAI models in Zed's agent.",
))
@@ -1085,7 +1142,10 @@ mod tests {
credentials: Some(make_expired_credentials()),
sign_in_task: None,
refresh_task: None,
+ load_task: None,
credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+ auth_generation: 0,
+ last_auth_error: None,
});
let weak_state = cx.read(|_cx| state.downgrade());
@@ -1138,7 +1198,10 @@ mod tests {
credentials: Some(make_fresh_credentials()),
sign_in_task: None,
refresh_task: None,
+ load_task: None,
credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+ auth_generation: 0,
+ last_auth_error: None,
});
let weak_state = cx.read(|_cx| state.downgrade());
@@ -1171,7 +1234,10 @@ mod tests {
credentials: None,
sign_in_task: None,
refresh_task: None,
+ load_task: None,
credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+ auth_generation: 0,
+ last_auth_error: None,
});
let weak_state = cx.read(|_cx| state.downgrade());
@@ -1188,4 +1254,225 @@ mod tests {
Err(LanguageModelCompletionError::NoApiKey { .. })
));
}
+
+ #[gpui::test]
+ async fn test_fatal_refresh_clears_auth_state(cx: &mut TestAppContext) {
+ let http_client = FakeHttpClient::create(move |_request| async move {
+ Ok(http_client::Response::builder()
+ .status(401)
+ .body(http_client::AsyncBody::from(r#"{"error":"invalid_grant"}"#))?)
+ });
+
+ let creds_provider = Arc::new(FakeCredentialsProvider::new());
+ let state = cx.new(|_cx| State {
+ credentials: Some(make_expired_credentials()),
+ sign_in_task: None,
+ refresh_task: None,
+ load_task: None,
+ credentials_provider: creds_provider.clone(),
+ auth_generation: 0,
+ last_auth_error: None,
+ });
+
+ let weak_state = cx.read(|_cx| state.downgrade());
+ let http: Arc<dyn HttpClient> = http_client;
+
+ let weak = weak_state.clone();
+ let http_clone = http.clone();
+ let result = cx
+ .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+ .await;
+
+ cx.run_until_parked();
+
+ assert!(result.is_err(), "fatal refresh should return an error");
+ cx.read(|cx| {
+ let s = state.read(cx);
+ assert!(
+ s.credentials.is_none(),
+ "credentials should be cleared on fatal refresh failure"
+ );
+ assert!(
+ s.last_auth_error.is_some(),
+ "last_auth_error should be set on fatal refresh failure"
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_transient_refresh_keeps_credentials(cx: &mut TestAppContext) {
+ let http_client = FakeHttpClient::create(move |_request| async move {
+ Ok(http_client::Response::builder()
+ .status(500)
+ .body(http_client::AsyncBody::from("Internal Server Error"))?)
+ });
+
+ let state = cx.new(|_cx| State {
+ credentials: Some(make_expired_credentials()),
+ sign_in_task: None,
+ refresh_task: None,
+ load_task: None,
+ credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+ auth_generation: 0,
+ last_auth_error: None,
+ });
+
+ let weak_state = cx.read(|_cx| state.downgrade());
+ let http: Arc<dyn HttpClient> = http_client;
+
+ let weak = weak_state.clone();
+ let http_clone = http.clone();
+ let result = cx
+ .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+ .await;
+
+ cx.run_until_parked();
+
+ assert!(result.is_err(), "transient refresh should return an error");
+ cx.read(|cx| {
+ let s = state.read(cx);
+ assert!(
+ s.credentials.is_some(),
+ "credentials should be kept on transient refresh failure"
+ );
+ assert!(
+ s.last_auth_error.is_none(),
+ "last_auth_error should not be set on transient refresh failure"
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_sign_out_during_refresh_discards_result(cx: &mut TestAppContext) {
+ let (gate_tx, gate_rx) = futures::channel::oneshot::channel::<()>();
+ let gate_rx = Arc::new(Mutex::new(Some(gate_rx)));
+ let gate_rx_clone = gate_rx.clone();
+
+ let http_client = FakeHttpClient::create(move |_request| {
+ let gate_rx = gate_rx_clone.clone();
+ async move {
+ // Wait until the gate is opened, simulating a slow network.
+ let rx = gate_rx.lock().take();
+ if let Some(rx) = rx {
+ let _ = rx.await;
+ }
+ let body = fake_token_response();
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(http_client::AsyncBody::from(body))?)
+ }
+ });
+
+ let creds_provider = Arc::new(FakeCredentialsProvider::new());
+ let state = cx.new(|_cx| State {
+ credentials: Some(make_expired_credentials()),
+ sign_in_task: None,
+ refresh_task: None,
+ load_task: None,
+ credentials_provider: creds_provider.clone(),
+ auth_generation: 0,
+ last_auth_error: None,
+ });
+
+ let weak_state = cx.read(|_cx| state.downgrade());
+ let http: Arc<dyn HttpClient> = http_client;
+
+ // Start a refresh
+ let weak = weak_state.clone();
+ let http_clone = http.clone();
+ let refresh_task =
+ cx.spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await);
+
+ cx.run_until_parked();
+
+ // Sign out while the refresh is in-flight
+ cx.update(|cx| {
+ do_sign_out(&weak_state, cx).detach();
+ });
+ cx.run_until_parked();
+
+ // Now let the refresh respond by opening the gate
+ let _ = gate_tx.send(());
+ cx.run_until_parked();
+
+ let result = refresh_task.await;
+ assert!(result.is_err(), "refresh should fail after sign-out");
+
+ cx.read(|cx| {
+ let s = state.read(cx);
+ assert!(
+ s.credentials.is_none(),
+ "sign-out should have cleared credentials"
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_sign_out_completes_fully(cx: &mut TestAppContext) {
+ let creds_provider = Arc::new(FakeCredentialsProvider::new());
+ // Pre-populate the credential store
+ creds_provider
+ .storage
+ .lock()
+ .replace(("Bearer".to_string(), b"some-creds".to_vec()));
+
+ let state = cx.new(|_cx| State {
+ credentials: Some(make_fresh_credentials()),
+ sign_in_task: None,
+ refresh_task: None,
+ load_task: None,
+ credentials_provider: creds_provider.clone(),
+ auth_generation: 0,
+ last_auth_error: None,
+ });
+
+ let weak_state = cx.read(|_cx| state.downgrade());
+ let sign_out_task = cx.update(|cx| do_sign_out(&weak_state, cx));
+
+ cx.run_until_parked();
+ sign_out_task.await.expect("sign-out should succeed");
+
+ assert!(
+ creds_provider.storage.lock().is_none(),
+ "credential store should be empty after sign-out"
+ );
+ cx.read(|cx| {
+ assert!(
+ !state.read(cx).is_authenticated(),
+ "state should show not authenticated"
+ );
+ });
+ }
+
+ #[gpui::test]
+ async fn test_authenticate_awaits_initial_load(cx: &mut TestAppContext) {
+ let creds = make_fresh_credentials();
+ let creds_json = serde_json::to_vec(&creds).unwrap();
+ let creds_provider = Arc::new(FakeCredentialsProvider::new());
+ creds_provider
+ .storage
+ .lock()
+ .replace(("Bearer".to_string(), creds_json));
+
+ let http_client = FakeHttpClient::create(|_| async {
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(http_client::AsyncBody::default())?)
+ });
+
+ let provider =
+ cx.update(|cx| OpenAiSubscribedProvider::new(http_client, creds_provider, cx));
+
+ // Before load completes, authenticate should still await the load.
+ let auth_task = cx.update(|cx| provider.authenticate(cx));
+
+ // Drive the load to completion.
+ cx.run_until_parked();
+
+ let result = auth_task.await;
+ assert!(
+ result.is_ok(),
+ "authenticate should succeed after load completes with valid credentials"
+ );
+ }
}