Address code review: shared OAuth callback server, auth lifecycle fixes, token counting

Richard Feldman created

- Extract OAuth callback server into http_client crate for reuse
- HTML-escape oauth_callback_page parameters, add error styling
- Update context_server to delegate to shared OAuth callback server
- Rewrite openai_subscribed OAuth flow: bind an ephemeral port and open the browser only after the listener is ready
- Add auth_generation counter to prevent stale refresh writes after sign-out
- Make do_sign_out awaitable, cancel in-flight work immediately
- Distinguish fatal (400/401/403) vs transient (5xx) refresh errors
- Clear credentials on fatal refresh errors; keep them on transient errors
- Make authenticate() await initial credential load
- Surface auth errors in ConfigurationView UI
- Implement real token counting via tiktoken
- Add tests for fatal/transient refresh, sign-out during refresh, authenticate-awaits-load

Change summary

Cargo.lock                                               |   3 
crates/context_server/Cargo.toml                         |   1 
crates/context_server/src/oauth.rs                       | 169 --
crates/http_client/Cargo.toml                            |   1 
crates/http_client/src/http_client.rs                    | 220 +++
crates/language_models/src/provider/openai_subscribed.rs | 557 +++++++--
6 files changed, 663 insertions(+), 288 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3454,7 +3454,6 @@ dependencies = [
  "smol",
  "tempfile",
  "terminal",
- "tiny_http",
  "url",
  "util",
 ]
@@ -8217,6 +8216,7 @@ dependencies = [
  "serde_urlencoded",
  "sha2",
  "tempfile",
+ "tiny_http",
  "url",
  "util",
 ]
@@ -9388,6 +9388,7 @@ dependencies = [
  "open_ai",
  "open_router",
  "opencode",
+ "parking_lot",
  "partial-json-fixer",
  "pretty_assertions",
  "rand 0.9.2",

crates/context_server/Cargo.toml 🔗

@@ -35,7 +35,6 @@ sha2.workspace = true
 slotmap.workspace = true
 smol.workspace = true
 tempfile.workspace = true
-tiny_http.workspace = true
 url = { workspace = true, features = ["serde"] }
 util.workspace = true
 terminal.workspace = true

crates/context_server/src/oauth.rs 🔗

@@ -27,11 +27,9 @@ use rand::Rng as _;
 use serde::{Deserialize, Serialize};
 use sha2::{Digest, Sha256};
 
-use std::str::FromStr;
 use std::sync::Arc;
 use std::time::{Duration, SystemTime};
 use url::Url;
-use util::ResultExt as _;
 
 /// The CIMD URL where Zed's OAuth client metadata document is hosted.
 pub const CIMD_URL: &str = "https://zed.dev/oauth/client-metadata.json";
@@ -992,58 +990,14 @@ impl OAuthCallback {
     /// Parse the query string from a callback URL like
     /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
     pub fn parse_query(query: &str) -> Result<Self> {
-        let mut code: Option<String> = None;
-        let mut state: Option<String> = None;
-        let mut error: Option<String> = None;
-        let mut error_description: Option<String> = None;
-
-        for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
-            match key.as_ref() {
-                "code" => {
-                    if !value.is_empty() {
-                        code = Some(value.into_owned());
-                    }
-                }
-                "state" => {
-                    if !value.is_empty() {
-                        state = Some(value.into_owned());
-                    }
-                }
-                "error" => {
-                    if !value.is_empty() {
-                        error = Some(value.into_owned());
-                    }
-                }
-                "error_description" => {
-                    if !value.is_empty() {
-                        error_description = Some(value.into_owned());
-                    }
-                }
-                _ => {}
-            }
-        }
-
-        // Check for OAuth error response (RFC 6749 Section 4.1.2.1) before
-        // checking for missing code/state.
-        if let Some(error_code) = error {
-            bail!(
-                "OAuth authorization failed: {} ({})",
-                error_code,
-                error_description.as_deref().unwrap_or("no description")
-            );
-        }
-
-        let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
-        let state = state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
-
-        Ok(Self { code, state })
+        let params = http_client::OAuthCallbackParams::parse_query(query)?;
+        Ok(Self {
+            code: params.code,
+            state: params.state,
+        })
     }
 }
 
-/// How long to wait for the browser to complete the OAuth flow before giving
-/// up and releasing the loopback port.
-const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
-
 /// Start a loopback HTTP server to receive the OAuth authorization callback.
 ///
 /// Binds to an ephemeral loopback port for each flow.
@@ -1057,107 +1011,26 @@ const CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
 /// HTML page telling the user they can close the tab, and shuts down.
 ///
 /// The callback server shuts down when the returned oneshot receiver is dropped
-/// (e.g. because the authentication task was cancelled), or after a timeout
-/// ([CALLBACK_TIMEOUT]).
+/// (e.g. because the authentication task was cancelled), or after a timeout.
 pub async fn start_callback_server() -> Result<(
     String,
     futures::channel::oneshot::Receiver<Result<OAuthCallback>>,
 )> {
-    let server = tiny_http::Server::http("127.0.0.1:0")
-        .map_err(|e| anyhow!(e).context("Failed to bind loopback listener for OAuth callback"))?;
-    let port = server
-        .server_addr()
-        .to_ip()
-        .context("server not bound to a TCP address")?
-        .port();
-
-    let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
-
-    let (tx, rx) = futures::channel::oneshot::channel();
-
-    // `tiny_http` is blocking, so we run it on a background thread.
-    // The `recv_timeout` loop lets us check for cancellation (the receiver
-    // being dropped) and enforce an overall timeout.
-    std::thread::spawn(move || {
-        let deadline = std::time::Instant::now() + CALLBACK_TIMEOUT;
-
-        loop {
-            if tx.is_canceled() {
-                return;
-            }
-            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
-            if remaining.is_zero() {
-                return;
-            }
-
-            let timeout = remaining.min(Duration::from_millis(500));
-            let Some(request) = (match server.recv_timeout(timeout) {
-                Ok(req) => req,
-                Err(_) => {
-                    let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
-                    return;
-                }
-            }) else {
-                // Timeout with no request — loop back and check cancellation.
-                continue;
-            };
-
-            let result = handle_callback_request(&request);
-
-            let (status_code, body) = match &result {
-                Ok(_) => (
-                    200,
-                    http_client::oauth_callback_page(
-                        "Authorization Successful",
-                        "You can close this tab and return to Zed.",
-                    ),
-                ),
-                Err(err) => {
-                    log::error!("OAuth callback error: {}", err);
-                    (
-                        400,
-                        http_client::oauth_callback_page(
-                            "Authorization Failed",
-                            "Something went wrong. Please try again from Zed.",
-                        ),
-                    )
-                }
-            };
-
-            let response = tiny_http::Response::from_string(body)
-                .with_status_code(status_code)
-                .with_header(
-                    tiny_http::Header::from_str("Content-Type: text/html")
-                        .expect("failed to construct response header"),
-                )
-                .with_header(
-                    tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
-                        .expect("failed to construct response header"),
-                );
-            request.respond(response).log_err();
-
-            let _ = tx.send(result);
-            return;
-        }
-    });
-
-    Ok((redirect_uri, rx))
-}
-
-/// Extract the `code` and `state` query parameters from an OAuth callback
-/// request to `/callback`.
-fn handle_callback_request(request: &tiny_http::Request) -> Result<OAuthCallback> {
-    let url = Url::parse(&format!("http://localhost{}", request.url()))
-        .context("malformed callback request URL")?;
-
-    if url.path() != "/callback" {
-        bail!("unexpected path in OAuth callback: {}", url.path());
-    }
-
-    let query = url
-        .query()
-        .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
-    OAuthCallback::parse_query(query)
+    let (redirect_uri, rx) = http_client::start_oauth_callback_server()?;
+    let (mut tx, mapped_rx) = futures::channel::oneshot::channel();
+    smol::spawn(async move {
+        // Forward the shared server's result, but give up as soon as the caller drops
+        // `mapped_rx`: returning drops `rx`, which is the signal the server thread
+        // watches for to shut down instead of holding the port until its timeout.
+        let received = match futures::future::select(rx, tx.cancellation()).await {
+            futures::future::Either::Left((received, _)) => received,
+            futures::future::Either::Right(_) => return,
+        };
+        let result = match received {
+            Ok(Ok(params)) => Ok(OAuthCallback {
+                code: params.code,
+                state: params.state,
+            }),
+            Ok(Err(e)) => Err(e),
+            Err(_) => Err(anyhow!("OAuth callback channel was cancelled")),
+        };
+        let _ = tx.send(result);
+    })
+    .detach();
+    Ok((redirect_uri, mapped_rx))
 }
 
 // -- JSON fetch helper -------------------------------------------------------

crates/http_client/Cargo.toml 🔗

@@ -37,3 +37,4 @@ async-fs.workspace = true
 async-tar.workspace = true
 sha2.workspace = true
 tempfile.workspace = true
+tiny_http.workspace = true

crates/http_client/src/http_client.rs 🔗

@@ -291,12 +291,43 @@ impl HttpClient for HttpClientWithUrl {
     }
 }
 
+/// Escape `&`, `<`, `>`, `"` and `'` so `input` can be embedded in HTML text or attributes.
+fn html_escape(input: &str) -> String {
+    input.chars().fold(String::with_capacity(input.len()), |mut escaped, c| {
+        match c {
+            '&' => escaped.push_str("&amp;"),
+            '<' => escaped.push_str("&lt;"),
+            '>' => escaped.push_str("&gt;"),
+            '"' => escaped.push_str("&quot;"),
+            '\'' => escaped.push_str("&#x27;"),
+            other => escaped.push(other),
+        }
+        escaped
+    })
+}
+
 /// Generate a styled HTML page for OAuth callback responses.
 ///
 /// Returns a complete HTML document (no HTTP headers) with a centered card
 /// layout styled to match Zed's dark theme. The `title` is rendered as a
 /// heading and `message` as body text below it.
-pub fn oauth_callback_page(title: &str, message: &str) -> String {
+///
+/// When `is_error` is true, a red X icon is shown instead of the green
+/// checkmark.
+pub fn oauth_callback_page(title: &str, message: &str, is_error: bool) -> String {
+    let title = html_escape(title);
+    let message = html_escape(message);
+    let (icon_bg, icon_svg) = if is_error {
+        (
+            "#f38ba8",
+            r#"<svg viewBox="0 0 24 24"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>"#,
+        )
+    } else {
+        (
+            "#a6e3a1",
+            r#"<svg viewBox="0 0 24 24"><polyline points="20 6 9 17 4 12"/></svg>"#,
+        )
+    };
     format!(
         r#"<!DOCTYPE html>
 <html lang="en">
@@ -329,7 +360,7 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String {
     width: 48px;
     height: 48px;
     margin: 0 auto 1.5rem;
-    background: #a6e3a1;
+    background: {icon_bg};
     border-radius: 50%;
     display: flex;
     align-items: center;
@@ -364,7 +395,7 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String {
 <body>
 <div class="card">
   <div class="icon">
-    <svg viewBox="0 0 24 24"><polyline points="20 6 9 17 4 12"/></svg>
+    {icon_svg}
   </div>
   <h1>{title}</h1>
   <p>{message}</p>
@@ -374,6 +405,8 @@ pub fn oauth_callback_page(title: &str, message: &str) -> String {
 </html>"#,
         title = title,
         message = message,
+        icon_bg = icon_bg,
+        icon_svg = icon_svg,
     )
 }
 
@@ -529,3 +562,184 @@ impl HttpClient for FakeHttpClient {
         self
     }
 }
+
+// ---------------------------------------------------------------------------
+// Shared OAuth callback server (non-wasm only)
+// ---------------------------------------------------------------------------
+
+#[cfg(not(target_family = "wasm"))]
+mod oauth_callback_server {
+    use super::*;
+    use anyhow::Context as _;
+    use std::str::FromStr;
+    use std::time::Duration;
+
+    /// Parsed OAuth callback parameters from the authorization server redirect.
+    pub struct OAuthCallbackParams {
+        /// Value of the `code` query parameter.
+        pub code: String,
+        /// Value of the `state` query parameter.
+        pub state: String,
+    }
+
+    impl OAuthCallbackParams {
+        /// Parse the query string from a callback URL like
+        /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
+        ///
+        /// Empty values are ignored; if a key repeats, its last non-empty value wins.
+        ///
+        /// # Errors
+        ///
+        /// Returns an error if a non-empty `error` parameter is present (this is
+        /// checked before anything else), or if `code` or `state` is missing.
+        pub fn parse_query(query: &str) -> Result<Self> {
+            let mut code: Option<String> = None;
+            let mut state: Option<String> = None;
+            let mut error: Option<String> = None;
+            let mut error_description: Option<String> = None;
+
+            for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
+                let slot = match key.as_ref() {
+                    "code" => &mut code,
+                    "state" => &mut state,
+                    "error" => &mut error,
+                    "error_description" => &mut error_description,
+                    _ => continue,
+                };
+                if !value.is_empty() {
+                    *slot = Some(value.into_owned());
+                }
+            }
+
+            // An error response takes precedence over a missing `code`/`state`.
+            if let Some(error_code) = error {
+                anyhow::bail!(
+                    "OAuth authorization failed: {} ({})",
+                    error_code,
+                    error_description.as_deref().unwrap_or("no description")
+                );
+            }
+
+            let Some(code) = code else {
+                anyhow::bail!("missing 'code' parameter in OAuth callback");
+            };
+            let Some(state) = state else {
+                anyhow::bail!("missing 'state' parameter in OAuth callback");
+            };
+
+            Ok(Self { code, state })
+        }
+    }
+
+    /// How long to wait for the browser to complete the OAuth flow before giving
+    /// up and releasing the loopback port.
+    const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
+
+    /// Start a loopback HTTP server to receive the OAuth authorization callback.
+    ///
+    /// Binds to an ephemeral loopback port. Returns `(redirect_uri, callback_future)`.
+    /// The caller should use the redirect URI in the authorization request, open
+    /// the browser, then await the future to receive the callback.
+    ///
+    /// The server handles a single request and then stops. It also stops once the
+    /// receiver is dropped, or after `OAUTH_CALLBACK_TIMEOUT`, in which case the
+    /// receiver resolves to a timeout error.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the loopback listener cannot be bound or its address is not a TCP
+    /// address.
+    pub fn start_oauth_callback_server() -> Result<(
+        String,
+        futures::channel::oneshot::Receiver<Result<OAuthCallbackParams>>,
+    )> {
+        let server = tiny_http::Server::http("127.0.0.1:0").map_err(|e| {
+            anyhow!(e).context("Failed to bind loopback listener for OAuth callback")
+        })?;
+        let port = server
+            .server_addr()
+            .to_ip()
+            .ok_or_else(|| anyhow!("server not bound to a TCP address"))?
+            .port();
+
+        let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
+
+        let (tx, rx) = futures::channel::oneshot::channel();
+
+        // `tiny_http` is blocking, so the server runs on its own thread. Polling in
+        // short slices lets it notice a dropped receiver and enforce the deadline.
+        std::thread::spawn(move || {
+            let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT;
+
+            loop {
+                if tx.is_canceled() {
+                    return;
+                }
+                let remaining = deadline.saturating_duration_since(std::time::Instant::now());
+                if remaining.is_zero() {
+                    // Report the timeout explicitly; silently dropping `tx` would
+                    // make it indistinguishable from a cancelled flow.
+                    let _ = tx.send(Err(anyhow!(
+                        "OAuth callback timed out after {} seconds",
+                        OAUTH_CALLBACK_TIMEOUT.as_secs()
+                    )));
+                    return;
+                }
+
+                let timeout = remaining.min(Duration::from_millis(500));
+                let request = match server.recv_timeout(timeout) {
+                    Ok(Some(request)) => request,
+                    // Nothing arrived in this slice; re-check cancellation and deadline.
+                    Ok(None) => continue,
+                    Err(_) => {
+                        let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
+                        return;
+                    }
+                };
+
+                let result = handle_oauth_callback_request(&request);
+                if let Err(err) = &result {
+                    log::error!("OAuth callback error: {}", err);
+                }
+
+                let (status_code, title, message) = if result.is_ok() {
+                    (
+                        200,
+                        "Authorization Successful",
+                        "You can close this tab and return to Zed.",
+                    )
+                } else {
+                    (
+                        400,
+                        "Authorization Failed",
+                        "Something went wrong. Please try again from Zed.",
+                    )
+                };
+                let body = oauth_callback_page(title, message, result.is_err());
+
+                // Both header strings are constants, so a parse failure is a bug.
+                let header = |line: &str| {
+                    tiny_http::Header::from_str(line)
+                        .expect("failed to construct response header")
+                };
+                let response = tiny_http::Response::from_string(body)
+                    .with_status_code(status_code)
+                    .with_header(header("Content-Type: text/html"))
+                    .with_header(header("Keep-Alive: timeout=0,max=0"));
+                if let Err(err) = request.respond(response) {
+                    log::error!("Failed to send OAuth callback response: {}", err);
+                }
+
+                let _ = tx.send(result);
+                return;
+            }
+        });
+
+        Ok((redirect_uri, rx))
+    }
+
+    /// Check that `request` targets `/callback` and parse its query parameters.
+    fn handle_oauth_callback_request(request: &tiny_http::Request) -> Result<OAuthCallbackParams> {
+        // Prefix a placeholder origin so the request target parses as an absolute URL.
+        let absolute = format!("http://localhost{}", request.url());
+        let parsed = Url::parse(&absolute).context("malformed callback request URL")?;
+        if parsed.path() != "/callback" {
+            anyhow::bail!("unexpected path in OAuth callback: {}", parsed.path());
+        }
+        match parsed.query() {
+            Some(query) => OAuthCallbackParams::parse_query(query),
+            None => Err(anyhow!("OAuth callback has no query string")),
+        }
+    }
+}
+
+#[cfg(not(target_family = "wasm"))]
+pub use oauth_callback_server::{OAuthCallbackParams, start_oauth_callback_server};

crates/language_models/src/provider/openai_subscribed.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
 use base64::Engine as _;
 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
 use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, StreamExt, future::BoxFuture, future::Either, future::Shared};
+use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use language_model::{
@@ -15,9 +15,8 @@ use open_ai::{ReasoningEffort, responses::stream_response};
 use rand::RngCore as _;
 use serde::{Deserialize, Serialize};
 use sha2::{Digest, Sha256};
-use smol::io::{AsyncReadExt as _, AsyncWriteExt as _};
 use std::sync::Arc;
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use std::time::{SystemTime, UNIX_EPOCH};
 use ui::{ConfiguredApiCard, prelude::*};
 use url::form_urlencoded;
 use util::ResultExt as _;
@@ -32,7 +31,7 @@ const CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex";
 const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
 const OPENAI_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
 const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
-const REDIRECT_URI: &str = "http://localhost:1455/auth/callback";
+
 const CREDENTIALS_KEY: &str = "https://chatgpt.com/backend-api/codex";
 const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000;
 
@@ -56,7 +55,25 @@ pub struct State {
     credentials: Option<CodexCredentials>,
     sign_in_task: Option<Task<Result<()>>>,
     refresh_task: Option<Shared<Task<Result<CodexCredentials, Arc<anyhow::Error>>>>>,
+    load_task: Option<Shared<Task<Result<(), Arc<anyhow::Error>>>>>,
     credentials_provider: Arc<dyn CredentialsProvider>,
+    auth_generation: u64,
+    last_auth_error: Option<SharedString>,
+}
+
+/// Why a token refresh failed. The refresh task clears stored credentials on
+/// `Fatal` and keeps them on `Transient`.
+#[derive(Debug)]
+enum RefreshError {
+    Fatal(anyhow::Error),
+    Transient(anyhow::Error),
+}
+
+impl std::fmt::Display for RefreshError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let (RefreshError::Fatal(inner) | RefreshError::Transient(inner)) = self;
+        write!(f, "{inner}")
+    }
 }
 
 impl State {
@@ -88,7 +105,10 @@ impl OpenAiSubscribedProvider {
             credentials: None,
             sign_in_task: None,
             refresh_task: None,
+            load_task: None,
             credentials_provider,
+            auth_generation: 0,
+            last_auth_error: None,
         });
 
         let provider = Self { http_client, state };
@@ -100,31 +120,38 @@ impl OpenAiSubscribedProvider {
 
     fn load_credentials(&self, cx: &mut App) {
         let state = self.state.downgrade();
-        cx.spawn(async move |cx| {
-            let credentials_provider =
-                state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
-            let result = credentials_provider
-                .read_credentials(CREDENTIALS_KEY, &*cx)
-                .await;
-            state.update(cx, |s, cx| {
-                if let Ok(Some((_, bytes))) = result {
-                    match serde_json::from_slice::<CodexCredentials>(&bytes) {
-                        Ok(creds) => s.credentials = Some(creds),
-                        Err(err) => {
-                            log::warn!(
-                                "Failed to deserialize ChatGPT subscription credentials: {err}"
-                            );
+        let load_task = cx
+            .spawn(async move |cx| {
+                let credentials_provider =
+                    state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
+                let result = credentials_provider
+                    .read_credentials(CREDENTIALS_KEY, &*cx)
+                    .await;
+                state.update(cx, |s, cx| {
+                    if let Ok(Some((_, bytes))) = result {
+                        match serde_json::from_slice::<CodexCredentials>(&bytes) {
+                            Ok(creds) => s.credentials = Some(creds),
+                            Err(err) => {
+                                log::warn!(
+                                    "Failed to deserialize ChatGPT subscription credentials: {err}"
+                                );
+                            }
                         }
                     }
-                }
-                cx.notify();
+                    s.load_task = None;
+                    cx.notify();
+                })?;
+                Ok::<(), Arc<anyhow::Error>>(())
             })
-        })
-        .detach();
+            .shared();
+
+        self.state.update(cx, |s, _| {
+            s.load_task = Some(load_task);
+        });
     }
 
-    fn sign_out(&self, cx: &mut App) {
-        do_sign_out(&self.state.downgrade(), cx);
+    fn sign_out(&self, cx: &mut App) -> Task<Result<()>> {
+        do_sign_out(&self.state.downgrade(), cx)
     }
 
     fn create_language_model(&self, model: ChatGptModel) -> Arc<dyn LanguageModel> {
@@ -182,10 +209,29 @@ impl LanguageModelProvider for OpenAiSubscribedProvider {
         if self.is_authenticated(cx) {
             return Task::ready(Ok(()));
         }
-        Task::ready(Err(anyhow!(
-            "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
-        )
-        .into()))
+        let load_task = self.state.read(cx).load_task.clone();
+        if let Some(load_task) = load_task {
+            let weak_state = self.state.downgrade();
+            cx.spawn(async move |cx| {
+                let _ = load_task.await;
+                let is_auth = weak_state
+                    .read_with(&*cx, |s, _| s.is_authenticated())
+                    .unwrap_or(false);
+                if is_auth {
+                    Ok(())
+                } else {
+                    Err(anyhow!(
+                        "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
+                    )
+                    .into())
+                }
+            })
+        } else {
+            Task::ready(Err(anyhow!(
+                "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
+            )
+            .into()))
+        }
     }
 
     fn configuration_view(
@@ -201,8 +247,7 @@ impl LanguageModelProvider for OpenAiSubscribedProvider {
     }
 
     fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
-        self.sign_out(cx);
-        Task::ready(Ok(()))
+        self.sign_out(cx)
     }
 }
 
@@ -380,10 +425,20 @@ impl LanguageModel for OpenAiSubscribedLanguageModel {
 
     fn count_tokens(
         &self,
-        _request: LanguageModelRequest,
-        _cx: &App,
+        request: LanguageModelRequest,
+        cx: &App,
     ) -> BoxFuture<'static, Result<u64>> {
-        futures::future::ready(Ok(0)).boxed()
+        let max_token_count = self.model.max_token_count();
+        cx.background_spawn(async move {
+            let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
+            let model = if max_token_count >= 100_000 {
+                "gpt-4o"
+            } else {
+                "gpt-4"
+            };
+            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
+        })
+        .boxed()
     }
 
     fn stream_completion(
@@ -501,6 +556,11 @@ async fn get_fresh_credentials(
     let state_clone = state.clone();
     let refresh_token_value = creds.refresh_token.clone();
 
+    // Capture the generation so we can detect sign-outs that happened during refresh.
+    let generation = state
+        .read_with(&*cx, |s, _| s.auth_generation)
+        .map_err(LanguageModelCompletionError::Other)?;
+
     let shared_task = cx
         .spawn(async move |cx| {
             let result = refresh_token(&http_client_clone, &refresh_token_value).await;
@@ -508,6 +568,16 @@ async fn get_fresh_credentials(
             match result {
                 Ok(refreshed) => {
                     let persist_result: Result<CodexCredentials, Arc<anyhow::Error>> = async {
+                        // Check if auth_generation changed (sign-out during refresh).
+                        let current_generation = state_clone
+                            .read_with(&*cx, |s, _| s.auth_generation)
+                            .map_err(|e| Arc::new(e))?;
+                        if current_generation != generation {
+                            return Err(Arc::new(anyhow!(
+                                "Sign-out occurred during token refresh"
+                            )));
+                        }
+
                         let credentials_provider = state_clone
                             .read_with(&*cx, |s, _| s.credentials_provider.clone())
                             .map_err(|e| Arc::new(e))?;
@@ -540,7 +610,28 @@ async fn get_fresh_credentials(
 
                     persist_result
                 }
-                Err(e) => {
+                Err(RefreshError::Fatal(e)) => {
+                    log::error!("ChatGPT subscription token refresh failed fatally: {e:?}");
+                    let _ = state_clone.update(cx, |s, cx| {
+                        s.refresh_task = None;
+                        s.credentials = None;
+                        s.last_auth_error =
+                            Some("Your session has expired. Please sign in again.".into());
+                        cx.notify();
+                    });
+                    // Also clear the keychain so stale credentials aren't loaded next time.
+                    if let Ok(credentials_provider) =
+                        state_clone.read_with(&*cx, |s, _| s.credentials_provider.clone())
+                    {
+                        credentials_provider
+                            .delete_credentials(CREDENTIALS_KEY, &*cx)
+                            .await
+                            .log_err();
+                    }
+                    Err(Arc::new(e))
+                }
+                Err(RefreshError::Transient(e)) => {
+                    log::warn!("ChatGPT subscription token refresh failed transiently: {e:?}");
                     let _ = state_clone.update(cx, |s, _| {
                         s.refresh_task = None;
                     });
@@ -577,6 +668,10 @@ async fn do_oauth_flow(
     http_client: Arc<dyn HttpClient>,
     cx: &AsyncApp,
 ) -> Result<CodexCredentials> {
+    // Start the callback server FIRST so the redirect URI is ready
+    let (redirect_uri, callback_rx) = http_client::start_oauth_callback_server()
+        .context("Failed to start OAuth callback server")?;
+
     // PKCE verifier: 32 random bytes → base64url (no padding)
     let mut verifier_bytes = [0u8; 32];
     rand::rng().fill_bytes(&mut verifier_bytes);
@@ -596,7 +691,7 @@ async fn do_oauth_flow(
     auth_url
         .query_pairs_mut()
         .append_pair("client_id", CLIENT_ID)
-        .append_pair("redirect_uri", REDIRECT_URI)
+        .append_pair("redirect_uri", &redirect_uri)
         .append_pair("scope", "openid profile email offline_access")
         .append_pair("response_type", "code")
         .append_pair("code_challenge", &challenge)
@@ -605,13 +700,21 @@ async fn do_oauth_flow(
         .append_pair("codex_cli_simplified_flow", "true")
         .append_pair("originator", "zed");
 
+    // Open browser AFTER the listener is ready
     cx.update(|cx| cx.open_url(auth_url.as_str()));
 
-    let code = await_oauth_callback(&oauth_state, cx)
+    // Await the callback
+    let callback = callback_rx
         .await
+        .map_err(|_| anyhow!("OAuth callback was cancelled"))?
         .context("OAuth callback failed")?;
 
-    let tokens = exchange_code(&http_client, &code, &verifier)
+    // Validate CSRF state
+    if callback.state != oauth_state {
+        return Err(anyhow!("OAuth state mismatch"));
+    }
+
+    let tokens = exchange_code(&http_client, &callback.code, &verifier, &redirect_uri)
         .await
         .context("Token exchange failed")?;
 
@@ -630,93 +733,17 @@ async fn do_oauth_flow(
     })
 }
 
-async fn await_oauth_callback(expected_state: &str, cx: &AsyncApp) -> Result<String> {
-    let listener = smol::net::TcpListener::bind("127.0.0.1:1455")
-        .await
-        .context("Failed to bind to port 1455 for OAuth callback. Another application may be using this port.")?;
-
-    let accept_future = listener.accept();
-    let timeout_future = cx.background_executor().timer(Duration::from_secs(120));
-
-    let (mut stream, _) = match futures::future::select(
-        std::pin::pin!(accept_future),
-        std::pin::pin!(timeout_future),
-    )
-    .await
-    {
-        Either::Left((result, _)) => result?,
-        Either::Right((_, _)) => {
-            return Err(anyhow!(
-                "OAuth sign-in timed out after 2 minutes. Please try again."
-            ));
-        }
-    };
-
-    let mut buffer = vec![0u8; 4096];
-    let mut total = 0;
-    loop {
-        if total >= buffer.len() {
-            return Err(anyhow!("OAuth callback request too large"));
-        }
-        let n = stream.read(&mut buffer[total..]).await?;
-        if n == 0 {
-            break;
-        }
-        total += n;
-        if buffer[..total].windows(4).any(|w| w == b"\r\n\r\n") {
-            break;
-        }
-    }
-    let request_text = std::str::from_utf8(&buffer[..total])?;
-
-    // First line: "GET /auth/callback?code=...&state=... HTTP/1.1"
-    let path = request_text
-        .lines()
-        .next()
-        .and_then(|line| line.split_whitespace().nth(1))
-        .ok_or_else(|| anyhow!("Invalid HTTP request from browser"))?;
-
-    let query = path.split('?').nth(1).unwrap_or("");
-    let mut code: Option<String> = None;
-    let mut received_state: Option<String> = None;
-    for (key, value) in form_urlencoded::parse(query.as_bytes()) {
-        match key.as_ref() {
-            "code" => code = Some(value.into_owned()),
-            "state" => received_state = Some(value.into_owned()),
-            _ => {}
-        }
-    }
-
-    let page = http_client::oauth_callback_page(
-        "Signed In",
-        "You've signed into Zed via your ChatGPT subscription. You can close this tab and return to Zed.",
-    );
-    let response = format!(
-        "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\n\r\n{}",
-        page.len(),
-        page
-    );
-    stream.write_all(response.as_bytes()).await.log_err();
-
-    let received_state =
-        received_state.ok_or_else(|| anyhow!("Missing state in OAuth callback"))?;
-    if received_state != expected_state {
-        return Err(anyhow!("OAuth state mismatch"));
-    }
-
-    code.ok_or_else(|| anyhow!("Missing authorization code in OAuth callback"))
-}
-
 async fn exchange_code(
     client: &Arc<dyn HttpClient>,
     code: &str,
     verifier: &str,
+    redirect_uri: &str,
 ) -> Result<TokenResponse> {
     let body = form_urlencoded::Serializer::new(String::new())
         .append_pair("grant_type", "authorization_code")
         .append_pair("client_id", CLIENT_ID)
         .append_pair("code", code)
-        .append_pair("redirect_uri", REDIRECT_URI)
+        .append_pair("redirect_uri", redirect_uri)
         .append_pair("code_verifier", verifier)
         .finish();
 
@@ -743,7 +770,7 @@ async fn exchange_code(
 async fn refresh_token(
     client: &Arc<dyn HttpClient>,
     refresh_token: &str,
-) -> Result<CodexCredentials> {
+) -> Result<CodexCredentials, RefreshError> {
     let body = form_urlencoded::Serializer::new(String::new())
         .append_pair("grant_type", "refresh_token")
         .append_pair("client_id", CLIENT_ID)
@@ -754,20 +781,34 @@ async fn refresh_token(
         .method(Method::POST)
         .uri(OPENAI_TOKEN_URL)
         .header("Content-Type", "application/x-www-form-urlencoded")
-        .body(AsyncBody::from(body))?;
+        .body(AsyncBody::from(body))
+        .map_err(|e| RefreshError::Transient(e.into()))?;
 
-    let mut response = client.send(request).await?;
+    let mut response = client
+        .send(request)
+        .await
+        .map_err(|e| RefreshError::Transient(e))?;
+    let status = response.status();
     let mut body = String::new();
-    smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body).await?;
-
-    if !response.status().is_success() {
-        return Err(anyhow!(
-            "Token refresh failed (HTTP {}): {body}",
-            response.status()
-        ));
+    smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
+        .await
+        .map_err(|e| RefreshError::Transient(e.into()))?;
+
+    if !status.is_success() {
+        let err = anyhow!("Token refresh failed (HTTP {}): {body}", status);
+        // 400/401/403 indicate a revoked or invalid refresh token.
+        // 5xx and other errors are treated as transient.
+        if status == http_client::StatusCode::BAD_REQUEST
+            || status == http_client::StatusCode::UNAUTHORIZED
+            || status == http_client::StatusCode::FORBIDDEN
+        {
+            return Err(RefreshError::Fatal(err));
+        }
+        return Err(RefreshError::Transient(err));
     }
 
-    let tokens: TokenResponse = serde_json::from_str(&body)?;
+    let tokens: TokenResponse =
+        serde_json::from_str(&body).map_err(|e| RefreshError::Transient(e.into()))?;
     let jwt = tokens
         .id_token
         .as_deref()
@@ -876,6 +917,7 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
                             .update(cx, |s, cx| {
                                 s.credentials = Some(creds);
                                 s.sign_in_task = None;
+                                s.last_auth_error = None;
                                 cx.notify();
                             })
                             .log_err();
@@ -887,6 +929,8 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
                         weak_state
                             .update(cx, |s, cx| {
                                 s.sign_in_task = None;
+                                s.last_auth_error =
+                                    Some("Failed to save credentials. Please try again.".into());
                                 cx.notify();
                             })
                             .log_err();
@@ -898,6 +942,7 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
                 weak_state
                     .update(cx, |s, cx| {
                         s.sign_in_task = None;
+                        s.last_auth_error = Some("Sign-in failed. Please try again.".into());
                         cx.notify();
                     })
                     .log_err();
@@ -907,28 +952,36 @@ fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut
     });
 
     state.update(cx, |s, cx| {
+        s.last_auth_error = None;
         s.sign_in_task = Some(task);
         cx.notify();
     });
 }
 
-fn do_sign_out(state: &gpui::WeakEntity<State>, cx: &mut App) {
+fn do_sign_out(state: &gpui::WeakEntity<State>, cx: &mut App) -> Task<Result<()>> { // returns the deletion task; callers await or detach it
     let weak_state = state.clone();
+    // Clear credentials and cancel in-flight work immediately so the UI
+    // reflects the sign-out right away (this part runs synchronously, before the spawn below).
+    weak_state
+        .update(cx, |s, cx| {
+            s.auth_generation += 1; // marks any refresh started before this point as stale
+            s.credentials = None;
+            s.sign_in_task = None; // dropping the stored tasks is what cancels the in-flight work
+            s.refresh_task = None;
+            s.last_auth_error = None;
+            cx.notify();
+        })
+        .log_err();
+
     cx.spawn(async move |cx| {
         let credentials_provider =
             weak_state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
         credentials_provider
             .delete_credentials(CREDENTIALS_KEY, &*cx)
             .await
-            .log_err();
-        weak_state.update(cx, |s, cx| {
-            s.credentials = None;
-            s.sign_in_task = None;
-            cx.notify();
-        })?;
+            .context("Failed to delete ChatGPT subscription credentials from keychain")?; // deletion failure now propagates through the returned task
         anyhow::Ok(())
     })
-    .detach();
 }
 
 struct ConfigurationView {
@@ -952,7 +1005,7 @@ impl Render for ConfigurationView {
                     ConfiguredApiCard::new(SharedString::from(label))
                         .button_label("Sign Out")
                         .on_click(cx.listener(move |_this, _, _window, cx| {
-                            do_sign_out(&weak_state, cx);
+                            do_sign_out(&weak_state, cx).detach_and_log_err(cx);
                         })),
                 )
                 .into_any_element();
@@ -964,11 +1017,15 @@ impl Render for ConfigurationView {
                 .into_any_element();
         }
 
+        let last_auth_error = state.last_auth_error.clone();
         let provider_state = self.state.clone();
         let http_client = self.http_client.clone();
 
         v_flex()
             .gap_2()
+            .when_some(last_auth_error, |this, error| {
+                this.child(Label::new(error).color(Color::Error))
+            })
             .child(Label::new(
                 "Sign in with your ChatGPT Plus or Pro subscription to use OpenAI models in Zed's agent.",
             ))
@@ -1085,7 +1142,10 @@ mod tests {
             credentials: Some(make_expired_credentials()),
             sign_in_task: None,
             refresh_task: None,
+            load_task: None,
             credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+            auth_generation: 0,
+            last_auth_error: None,
         });
 
         let weak_state = cx.read(|_cx| state.downgrade());
@@ -1138,7 +1198,10 @@ mod tests {
             credentials: Some(make_fresh_credentials()),
             sign_in_task: None,
             refresh_task: None,
+            load_task: None,
             credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+            auth_generation: 0,
+            last_auth_error: None,
         });
 
         let weak_state = cx.read(|_cx| state.downgrade());
@@ -1171,7 +1234,10 @@ mod tests {
             credentials: None,
             sign_in_task: None,
             refresh_task: None,
+            load_task: None,
             credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+            auth_generation: 0,
+            last_auth_error: None,
         });
 
         let weak_state = cx.read(|_cx| state.downgrade());
@@ -1188,4 +1254,225 @@ mod tests {
             Err(LanguageModelCompletionError::NoApiKey { .. })
         ));
     }
+
+    #[gpui::test] // Fatal refresh path: a 401 from the token endpoint must drop credentials and record an error.
+    async fn test_fatal_refresh_clears_auth_state(cx: &mut TestAppContext) {
+        let http_client = FakeHttpClient::create(move |_request| async move {
+            Ok(http_client::Response::builder()
+                .status(401) // one of the statuses (400/401/403) that refresh_token reports as RefreshError::Fatal
+                .body(http_client::AsyncBody::from(r#"{"error":"invalid_grant"}"#))?)
+        });
+
+        let creds_provider = Arc::new(FakeCredentialsProvider::new());
+        let state = cx.new(|_cx| State {
+            credentials: Some(make_expired_credentials()), // presumably already expired so a refresh is attempted; helper is outside this hunk
+            sign_in_task: None,
+            refresh_task: None,
+            load_task: None,
+            credentials_provider: creds_provider.clone(),
+            auth_generation: 0,
+            last_auth_error: None,
+        });
+
+        let weak_state = cx.read(|_cx| state.downgrade());
+        let http: Arc<dyn HttpClient> = http_client;
+
+        let weak = weak_state.clone();
+        let http_clone = http.clone();
+        let result = cx
+            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+            .await;
+
+        cx.run_until_parked(); // let any follow-up work settle before inspecting state
+
+        assert!(result.is_err(), "fatal refresh should return an error");
+        cx.read(|cx| {
+            let s = state.read(cx);
+            assert!(
+                s.credentials.is_none(),
+                "credentials should be cleared on fatal refresh failure"
+            );
+            assert!(
+                s.last_auth_error.is_some(),
+                "last_auth_error should be set on fatal refresh failure"
+            );
+        });
+    }
+
+    #[gpui::test] // Transient refresh path: a 500 must fail the call but leave the stored credentials alone.
+    async fn test_transient_refresh_keeps_credentials(cx: &mut TestAppContext) {
+        let http_client = FakeHttpClient::create(move |_request| async move {
+            Ok(http_client::Response::builder()
+                .status(500) // not 400/401/403, so refresh_token reports it as RefreshError::Transient
+                .body(http_client::AsyncBody::from("Internal Server Error"))?)
+        });
+
+        let state = cx.new(|_cx| State {
+            credentials: Some(make_expired_credentials()), // presumably already expired so a refresh is attempted; helper is outside this hunk
+            sign_in_task: None,
+            refresh_task: None,
+            load_task: None,
+            credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+            auth_generation: 0,
+            last_auth_error: None,
+        });
+
+        let weak_state = cx.read(|_cx| state.downgrade());
+        let http: Arc<dyn HttpClient> = http_client;
+
+        let weak = weak_state.clone();
+        let http_clone = http.clone();
+        let result = cx
+            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+            .await;
+
+        cx.run_until_parked(); // let any follow-up work settle before inspecting state
+
+        assert!(result.is_err(), "transient refresh should return an error");
+        cx.read(|cx| {
+            let s = state.read(cx);
+            assert!(
+                s.credentials.is_some(),
+                "credentials should be kept on transient refresh failure"
+            );
+            assert!(
+                s.last_auth_error.is_none(),
+                "last_auth_error should not be set on transient refresh failure"
+            );
+        });
+    }
+
+    #[gpui::test] // Race check: a refresh that finishes after sign-out must not restore credentials.
+    async fn test_sign_out_during_refresh_discards_result(cx: &mut TestAppContext) {
+        let (gate_tx, gate_rx) = futures::channel::oneshot::channel::<()>();
+        let gate_rx = Arc::new(Mutex::new(Some(gate_rx))); // Option so the handler can take the receiver exactly once
+        let gate_rx_clone = gate_rx.clone();
+
+        let http_client = FakeHttpClient::create(move |_request| {
+            let gate_rx = gate_rx_clone.clone();
+            async move {
+                // Wait until the gate is opened, simulating a slow network.
+                let rx = gate_rx.lock().take(); // only the first request gets the receiver; later ones skip the wait
+                if let Some(rx) = rx {
+                    let _ = rx.await; // result ignored: the handler proceeds whatever the receiver yields
+                }
+                let body = fake_token_response();
+                Ok(http_client::Response::builder()
+                    .status(200)
+                    .body(http_client::AsyncBody::from(body))?)
+            }
+        });
+
+        let creds_provider = Arc::new(FakeCredentialsProvider::new());
+        let state = cx.new(|_cx| State {
+            credentials: Some(make_expired_credentials()), // presumably already expired so a refresh is attempted; helper is outside this hunk
+            sign_in_task: None,
+            refresh_task: None,
+            load_task: None,
+            credentials_provider: creds_provider.clone(),
+            auth_generation: 0,
+            last_auth_error: None,
+        });
+
+        let weak_state = cx.read(|_cx| state.downgrade());
+        let http: Arc<dyn HttpClient> = http_client;
+
+        // Start a refresh
+        let weak = weak_state.clone();
+        let http_clone = http.clone();
+        let refresh_task =
+            cx.spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await);
+
+        cx.run_until_parked(); // the refresh is expected to be blocked on the gated request at this point
+
+        // Sign out while the refresh is in-flight
+        cx.update(|cx| {
+            do_sign_out(&weak_state, cx).detach(); // state is reset synchronously; only the keychain deletion is detached
+        });
+        cx.run_until_parked();
+
+        // Now let the refresh respond by opening the gate
+        let _ = gate_tx.send(());
+        cx.run_until_parked();
+
+        let result = refresh_task.await;
+        assert!(result.is_err(), "refresh should fail after sign-out"); // NOTE(review): presumably rejected via the auth_generation bump; get_fresh_credentials is outside this hunk
+
+        cx.read(|cx| {
+            let s = state.read(cx);
+            assert!(
+                s.credentials.is_none(),
+                "sign-out should have cleared credentials"
+            );
+        });
+    }
+
+    #[gpui::test] // Awaiting do_sign_out must leave both the credential store and the in-memory state empty.
+    async fn test_sign_out_completes_fully(cx: &mut TestAppContext) {
+        let creds_provider = Arc::new(FakeCredentialsProvider::new());
+        // Pre-populate the credential store so sign-out has something to delete
+        creds_provider
+            .storage
+            .lock()
+            .replace(("Bearer".to_string(), b"some-creds".to_vec()));
+
+        let state = cx.new(|_cx| State {
+            credentials: Some(make_fresh_credentials()),
+            sign_in_task: None,
+            refresh_task: None,
+            load_task: None,
+            credentials_provider: creds_provider.clone(),
+            auth_generation: 0,
+            last_auth_error: None,
+        });
+
+        let weak_state = cx.read(|_cx| state.downgrade());
+        let sign_out_task = cx.update(|cx| do_sign_out(&weak_state, cx)); // keep the task so the deletion result can be checked
+
+        cx.run_until_parked();
+        sign_out_task.await.expect("sign-out should succeed");
+
+        assert!(
+            creds_provider.storage.lock().is_none(),
+            "credential store should be empty after sign-out"
+        );
+        cx.read(|cx| {
+            assert!(
+                !state.read(cx).is_authenticated(),
+                "state should show not authenticated"
+            );
+        });
+    }
+
+    #[gpui::test] // authenticate() called right after construction must succeed once stored credentials are loaded.
+    async fn test_authenticate_awaits_initial_load(cx: &mut TestAppContext) {
+        let creds = make_fresh_credentials();
+        let creds_json = serde_json::to_vec(&creds).unwrap(); // serialized form placed in the fake credential store below
+        let creds_provider = Arc::new(FakeCredentialsProvider::new());
+        creds_provider
+            .storage
+            .lock()
+            .replace(("Bearer".to_string(), creds_json));
+
+        let http_client = FakeHttpClient::create(|_| async {
+            Ok(http_client::Response::builder()
+                .status(200)
+                .body(http_client::AsyncBody::default())?)
+        });
+
+        let provider =
+            cx.update(|cx| OpenAiSubscribedProvider::new(http_client, creds_provider, cx)); // NOTE(review): presumably starts the initial credential load; confirm in new()
+
+        // Before load completes, authenticate should still await the load.
+        let auth_task = cx.update(|cx| provider.authenticate(cx)); // requested before the executor has been driven
+
+        // Drive the load to completion.
+        cx.run_until_parked();
+
+        let result = auth_task.await;
+        assert!(
+            result.is_ok(),
+            "authenticate should succeed after load completes with valid credentials"
+        );
+    }
 }