openai_subscribed.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use base64::Engine as _;
   3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
   4use credentials_provider::CredentialsProvider;
   5use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared};
   6use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
   7use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
   8use language_model::{
   9    AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  10    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  11    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  12    LanguageModelRequest, LanguageModelToolChoice, RateLimiter,
  13};
  14use open_ai::{ReasoningEffort, responses::stream_response};
  15use rand::RngCore as _;
  16use serde::{Deserialize, Serialize};
  17use sha2::{Digest, Sha256};
  18use std::sync::Arc;
  19use std::time::{SystemTime, UNIX_EPOCH};
  20use ui::{ConfiguredApiCard, prelude::*};
  21use url::form_urlencoded;
  22use util::ResultExt as _;
  23
  24use crate::provider::open_ai::{OpenAiResponseEventMapper, into_open_ai_response};
  25
  26const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai-subscribed");
  27const PROVIDER_NAME: LanguageModelProviderName =
  28    LanguageModelProviderName::new("ChatGPT Subscription");
  29
  30const CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex";
  31const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
  32const OPENAI_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
  33const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
  34
  35const CREDENTIALS_KEY: &str = "https://chatgpt.com/backend-api/codex";
  36const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000;
  37
  38#[derive(Serialize, Deserialize, Clone, Debug)]
  39struct CodexCredentials {
  40    access_token: String,
  41    refresh_token: String,
  42    expires_at_ms: u64,
  43    account_id: Option<String>,
  44    email: Option<String>,
  45}
  46
  47impl CodexCredentials {
  48    fn is_expired(&self) -> bool {
  49        let now = now_ms();
  50        now + TOKEN_REFRESH_BUFFER_MS >= self.expires_at_ms
  51    }
  52}
  53
  54pub struct State {
  55    credentials: Option<CodexCredentials>,
  56    sign_in_task: Option<Task<Result<()>>>,
  57    refresh_task: Option<Shared<Task<Result<CodexCredentials, Arc<anyhow::Error>>>>>,
  58    load_task: Option<Shared<Task<Result<(), Arc<anyhow::Error>>>>>,
  59    credentials_provider: Arc<dyn CredentialsProvider>,
  60    auth_generation: u64,
  61    last_auth_error: Option<SharedString>,
  62}
  63
  64#[derive(Debug)]
  65enum RefreshError {
  66    Fatal(anyhow::Error),
  67    Transient(anyhow::Error),
  68}
  69
  70impl std::fmt::Display for RefreshError {
  71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  72        match self {
  73            RefreshError::Fatal(e) => write!(f, "{e}"),
  74            RefreshError::Transient(e) => write!(f, "{e}"),
  75        }
  76    }
  77}
  78
  79impl State {
  80    fn is_authenticated(&self) -> bool {
  81        self.credentials.is_some()
  82    }
  83
  84    fn email(&self) -> Option<&str> {
  85        self.credentials.as_ref().and_then(|c| c.email.as_deref())
  86    }
  87
  88    fn is_signing_in(&self) -> bool {
  89        self.sign_in_task.is_some()
  90    }
  91}
  92
  93pub struct OpenAiSubscribedProvider {
  94    http_client: Arc<dyn HttpClient>,
  95    state: Entity<State>,
  96}
  97
  98impl OpenAiSubscribedProvider {
  99    pub fn new(
 100        http_client: Arc<dyn HttpClient>,
 101        credentials_provider: Arc<dyn CredentialsProvider>,
 102        cx: &mut App,
 103    ) -> Self {
 104        let state = cx.new(|_cx| State {
 105            credentials: None,
 106            sign_in_task: None,
 107            refresh_task: None,
 108            load_task: None,
 109            credentials_provider,
 110            auth_generation: 0,
 111            last_auth_error: None,
 112        });
 113
 114        let provider = Self { http_client, state };
 115
 116        provider.load_credentials(cx);
 117
 118        provider
 119    }
 120
 121    fn load_credentials(&self, cx: &mut App) {
 122        let state = self.state.downgrade();
 123        let load_task = cx
 124            .spawn(async move |cx| {
 125                let credentials_provider =
 126                    state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
 127                let result = credentials_provider
 128                    .read_credentials(CREDENTIALS_KEY, &*cx)
 129                    .await;
 130                state.update(cx, |s, cx| {
 131                    if let Ok(Some((_, bytes))) = result {
 132                        match serde_json::from_slice::<CodexCredentials>(&bytes) {
 133                            Ok(creds) => s.credentials = Some(creds),
 134                            Err(err) => {
 135                                log::warn!(
 136                                    "Failed to deserialize ChatGPT subscription credentials: {err}"
 137                                );
 138                            }
 139                        }
 140                    }
 141                    s.load_task = None;
 142                    cx.notify();
 143                })?;
 144                Ok::<(), Arc<anyhow::Error>>(())
 145            })
 146            .shared();
 147
 148        self.state.update(cx, |s, _| {
 149            s.load_task = Some(load_task);
 150        });
 151    }
 152
 153    fn sign_out(&self, cx: &mut App) -> Task<Result<()>> {
 154        do_sign_out(&self.state.downgrade(), cx)
 155    }
 156
 157    fn create_language_model(&self, model: ChatGptModel) -> Arc<dyn LanguageModel> {
 158        Arc::new(OpenAiSubscribedLanguageModel {
 159            id: LanguageModelId::from(model.id().to_string()),
 160            model,
 161            state: self.state.clone(),
 162            http_client: self.http_client.clone(),
 163            request_limiter: RateLimiter::new(4),
 164        })
 165    }
 166}
 167
 168impl LanguageModelProviderState for OpenAiSubscribedProvider {
 169    type ObservableEntity = State;
 170
 171    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 172        Some(self.state.clone())
 173    }
 174}
 175
 176impl LanguageModelProvider for OpenAiSubscribedProvider {
 177    fn id(&self) -> LanguageModelProviderId {
 178        PROVIDER_ID
 179    }
 180
 181    fn name(&self) -> LanguageModelProviderName {
 182        PROVIDER_NAME
 183    }
 184
 185    fn icon(&self) -> IconOrSvg {
 186        IconOrSvg::Icon(IconName::AiOpenAi)
 187    }
 188
 189    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 190        Some(self.create_language_model(ChatGptModel::Gpt54))
 191    }
 192
 193    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 194        Some(self.create_language_model(ChatGptModel::Gpt54Mini))
 195    }
 196
 197    fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 198        ChatGptModel::all()
 199            .into_iter()
 200            .map(|m| self.create_language_model(m))
 201            .collect()
 202    }
 203
 204    fn is_authenticated(&self, cx: &App) -> bool {
 205        self.state.read(cx).is_authenticated()
 206    }
 207
 208    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 209        if self.is_authenticated(cx) {
 210            return Task::ready(Ok(()));
 211        }
 212        let load_task = self.state.read(cx).load_task.clone();
 213        if let Some(load_task) = load_task {
 214            let weak_state = self.state.downgrade();
 215            cx.spawn(async move |cx| {
 216                let _ = load_task.await;
 217                let is_auth = weak_state
 218                    .read_with(&*cx, |s, _| s.is_authenticated())
 219                    .unwrap_or(false);
 220                if is_auth {
 221                    Ok(())
 222                } else {
 223                    Err(anyhow!(
 224                        "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
 225                    )
 226                    .into())
 227                }
 228            })
 229        } else {
 230            Task::ready(Err(anyhow!(
 231                "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
 232            )
 233            .into()))
 234        }
 235    }
 236
 237    fn configuration_view(
 238        &self,
 239        _target_agent: language_model::ConfigurationViewTargetAgent,
 240        _window: &mut Window,
 241        cx: &mut App,
 242    ) -> AnyView {
 243        let state = self.state.clone();
 244        let http_client = self.http_client.clone();
 245        cx.new(|_cx| ConfigurationView { state, http_client })
 246            .into()
 247    }
 248
 249    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 250        self.sign_out(cx)
 251    }
 252}
 253
 254//
 255// The ChatGPT Subscription provider routes requests to chatgpt.com/backend-api/codex,
 256// which only supports a subset of OpenAI models. This list is maintained separately
 257// from the standard OpenAI API model list (open_ai::Model).
 258
 259#[derive(Clone, Debug, PartialEq)]
 260enum ChatGptModel {
 261    Gpt5,
 262    Gpt5Codex,
 263    Gpt5CodexMini,
 264    Gpt51,
 265    Gpt51Codex,
 266    Gpt51CodexMax,
 267    Gpt51CodexMini,
 268    Gpt52,
 269    Gpt52Codex,
 270    Gpt53Codex,
 271    Gpt53CodexSpark,
 272    Gpt54,
 273    Gpt54Mini,
 274}
 275
 276impl ChatGptModel {
 277    fn all() -> Vec<Self> {
 278        vec![
 279            Self::Gpt54,
 280            Self::Gpt54Mini,
 281            Self::Gpt53Codex,
 282            Self::Gpt53CodexSpark,
 283            Self::Gpt52Codex,
 284            Self::Gpt52,
 285            Self::Gpt51CodexMax,
 286            Self::Gpt51Codex,
 287            Self::Gpt51CodexMini,
 288            Self::Gpt51,
 289            Self::Gpt5Codex,
 290            Self::Gpt5CodexMini,
 291            Self::Gpt5,
 292        ]
 293    }
 294
 295    fn id(&self) -> &str {
 296        match self {
 297            Self::Gpt5 => "gpt-5",
 298            Self::Gpt5Codex => "gpt-5-codex",
 299            Self::Gpt5CodexMini => "gpt-5-codex-mini",
 300            Self::Gpt51 => "gpt-5.1",
 301            Self::Gpt51Codex => "gpt-5.1-codex",
 302            Self::Gpt51CodexMax => "gpt-5.1-codex-max",
 303            Self::Gpt51CodexMini => "gpt-5.1-codex-mini",
 304            Self::Gpt52 => "gpt-5.2",
 305            Self::Gpt52Codex => "gpt-5.2-codex",
 306            Self::Gpt53Codex => "gpt-5.3-codex",
 307            Self::Gpt53CodexSpark => "gpt-5.3-codex-spark",
 308            Self::Gpt54 => "gpt-5.4",
 309            Self::Gpt54Mini => "gpt-5.4-mini",
 310        }
 311    }
 312
 313    fn display_name(&self) -> &str {
 314        match self {
 315            Self::Gpt5 => "GPT-5",
 316            Self::Gpt5Codex => "GPT-5 Codex",
 317            Self::Gpt5CodexMini => "GPT-5 Codex Mini",
 318            Self::Gpt51 => "GPT-5.1",
 319            Self::Gpt51Codex => "GPT-5.1 Codex",
 320            Self::Gpt51CodexMax => "GPT-5.1 Codex Max",
 321            Self::Gpt51CodexMini => "GPT-5.1 Codex Mini",
 322            Self::Gpt52 => "GPT-5.2",
 323            Self::Gpt52Codex => "GPT-5.2 Codex",
 324            Self::Gpt53Codex => "GPT-5.3 Codex",
 325            Self::Gpt53CodexSpark => "GPT-5.3 Codex Spark",
 326            Self::Gpt54 => "GPT-5.4",
 327            Self::Gpt54Mini => "GPT-5.4 Mini",
 328        }
 329    }
 330
 331    fn max_token_count(&self) -> u64 {
 332        match self {
 333            Self::Gpt53CodexSpark => 128_000,
 334            Self::Gpt54 | Self::Gpt54Mini => 1_050_000,
 335            _ => 400_000,
 336        }
 337    }
 338
 339    fn max_output_tokens(&self) -> Option<u64> {
 340        match self {
 341            Self::Gpt53CodexSpark => Some(8_192),
 342            _ => Some(128_000),
 343        }
 344    }
 345
 346    fn supports_images(&self) -> bool {
 347        !matches!(self, Self::Gpt53CodexSpark)
 348    }
 349
 350    fn reasoning_effort(&self) -> Option<ReasoningEffort> {
 351        match self {
 352            Self::Gpt54 | Self::Gpt54Mini => None,
 353            _ => Some(ReasoningEffort::Medium),
 354        }
 355    }
 356
 357    fn supports_parallel_tool_calls(&self) -> bool {
 358        match self {
 359            Self::Gpt54 | Self::Gpt54Mini => true,
 360            _ => false,
 361        }
 362    }
 363
 364    fn supports_prompt_cache_key(&self) -> bool {
 365        true
 366    }
 367}
 368
 369struct OpenAiSubscribedLanguageModel {
 370    id: LanguageModelId,
 371    model: ChatGptModel,
 372    state: Entity<State>,
 373    http_client: Arc<dyn HttpClient>,
 374    request_limiter: RateLimiter,
 375}
 376
 377impl LanguageModel for OpenAiSubscribedLanguageModel {
 378    fn id(&self) -> LanguageModelId {
 379        self.id.clone()
 380    }
 381
 382    fn name(&self) -> LanguageModelName {
 383        LanguageModelName::from(self.model.display_name().to_string())
 384    }
 385
 386    fn provider_id(&self) -> LanguageModelProviderId {
 387        PROVIDER_ID
 388    }
 389
 390    fn provider_name(&self) -> LanguageModelProviderName {
 391        PROVIDER_NAME
 392    }
 393
 394    fn supports_tools(&self) -> bool {
 395        true
 396    }
 397
 398    fn supports_images(&self) -> bool {
 399        self.model.supports_images()
 400    }
 401
 402    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
 403        true
 404    }
 405
 406    fn supports_streaming_tools(&self) -> bool {
 407        true
 408    }
 409
 410    fn supports_thinking(&self) -> bool {
 411        self.model.reasoning_effort().is_some()
 412    }
 413
 414    fn telemetry_id(&self) -> String {
 415        format!("openai-subscribed/{}", self.model.id())
 416    }
 417
 418    fn max_token_count(&self) -> u64 {
 419        self.model.max_token_count()
 420    }
 421
 422    fn max_output_tokens(&self) -> Option<u64> {
 423        self.model.max_output_tokens()
 424    }
 425
 426    fn count_tokens(
 427        &self,
 428        request: LanguageModelRequest,
 429        cx: &App,
 430    ) -> BoxFuture<'static, Result<u64>> {
 431        let max_token_count = self.model.max_token_count();
 432        cx.background_spawn(async move {
 433            let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
 434            let model = if max_token_count >= 100_000 {
 435                "gpt-4o"
 436            } else {
 437                "gpt-4"
 438            };
 439            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
 440        })
 441        .boxed()
 442    }
 443
 444    fn stream_completion(
 445        &self,
 446        request: LanguageModelRequest,
 447        cx: &AsyncApp,
 448    ) -> BoxFuture<
 449        'static,
 450        Result<
 451            futures::stream::BoxStream<
 452                'static,
 453                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
 454            >,
 455            LanguageModelCompletionError,
 456        >,
 457    > {
 458        let mut responses_request = into_open_ai_response(
 459            request,
 460            self.model.id(),
 461            self.model.supports_parallel_tool_calls(),
 462            self.model.supports_prompt_cache_key(),
 463            self.max_output_tokens(),
 464            self.model.reasoning_effort(),
 465        );
 466        responses_request.store = Some(false);
 467
 468        // The Codex backend requires system messages to be in the top-level
 469        // `instructions` field rather than as input items.
 470        let mut instructions = Vec::new();
 471        responses_request.input.retain(|item| {
 472            if let open_ai::responses::ResponseInputItem::Message(msg) = item {
 473                if msg.role == open_ai::Role::System {
 474                    for part in &msg.content {
 475                        if let open_ai::responses::ResponseInputContent::Text { text } = part {
 476                            instructions.push(text.clone());
 477                        }
 478                    }
 479                    return false;
 480                }
 481            }
 482            true
 483        });
 484        if !instructions.is_empty() {
 485            responses_request.instructions = Some(instructions.join("\n\n"));
 486        }
 487
 488        let state = self.state.downgrade();
 489        let http_client = self.http_client.clone();
 490        let request_limiter = self.request_limiter.clone();
 491
 492        let future = cx.spawn(async move |cx| {
 493            let creds = get_fresh_credentials(&state, &http_client, cx).await?;
 494
 495            let mut extra_headers: Vec<(String, String)> = vec![
 496                ("originator".into(), "zed".into()),
 497                ("OpenAI-Beta".into(), "responses=experimental".into()),
 498            ];
 499            if let Some(ref id) = creds.account_id {
 500                if !id.is_empty() {
 501                    extra_headers.push(("ChatGPT-Account-Id".into(), id.clone()));
 502                }
 503            }
 504
 505            let access_token = creds.access_token.clone();
 506            request_limiter
 507                .stream(async move {
 508                    stream_response(
 509                        http_client.as_ref(),
 510                        PROVIDER_NAME.0.as_str(),
 511                        CODEX_BASE_URL,
 512                        &access_token,
 513                        responses_request,
 514                        extra_headers,
 515                    )
 516                    .await
 517                    .map_err(LanguageModelCompletionError::from)
 518                })
 519                .await
 520        });
 521
 522        async move {
 523            let mapper = OpenAiResponseEventMapper::new();
 524            Ok(mapper.map_stream(future.await?.boxed()).boxed())
 525        }
 526        .boxed()
 527    }
 528}
 529
 530async fn get_fresh_credentials(
 531    state: &gpui::WeakEntity<State>,
 532    http_client: &Arc<dyn HttpClient>,
 533    cx: &mut AsyncApp,
 534) -> Result<CodexCredentials, LanguageModelCompletionError> {
 535    let (creds, existing_task) = state
 536        .read_with(&*cx, |s, _| (s.credentials.clone(), s.refresh_task.clone()))
 537        .map_err(LanguageModelCompletionError::Other)?;
 538
 539    let creds = creds.ok_or(LanguageModelCompletionError::NoApiKey {
 540        provider: PROVIDER_NAME,
 541    })?;
 542
 543    if !creds.is_expired() {
 544        return Ok(creds);
 545    }
 546
 547    // If another caller is already refreshing, await their result.
 548    if let Some(shared_task) = existing_task {
 549        return shared_task
 550            .await
 551            .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")));
 552    }
 553
 554    // We are the first caller to notice expiry — spawn the refresh task.
 555    let http_client_clone = http_client.clone();
 556    let state_clone = state.clone();
 557    let refresh_token_value = creds.refresh_token.clone();
 558
 559    // Capture the generation so we can detect sign-outs that happened during refresh.
 560    let generation = state
 561        .read_with(&*cx, |s, _| s.auth_generation)
 562        .map_err(LanguageModelCompletionError::Other)?;
 563
 564    let shared_task = cx
 565        .spawn(async move |cx| {
 566            let result = refresh_token(&http_client_clone, &refresh_token_value).await;
 567
 568            match result {
 569                Ok(refreshed) => {
 570                    let persist_result: Result<CodexCredentials, Arc<anyhow::Error>> = async {
 571                        // Check if auth_generation changed (sign-out during refresh).
 572                        let current_generation = state_clone
 573                            .read_with(&*cx, |s, _| s.auth_generation)
 574                            .map_err(|e| Arc::new(e))?;
 575                        if current_generation != generation {
 576                            return Err(Arc::new(anyhow!(
 577                                "Sign-out occurred during token refresh"
 578                            )));
 579                        }
 580
 581                        let credentials_provider = state_clone
 582                            .read_with(&*cx, |s, _| s.credentials_provider.clone())
 583                            .map_err(|e| Arc::new(e))?;
 584
 585                        let json =
 586                            serde_json::to_vec(&refreshed).map_err(|e| Arc::new(e.into()))?;
 587
 588                        credentials_provider
 589                            .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
 590                            .await
 591                            .map_err(|e| Arc::new(e))?;
 592
 593                        state_clone
 594                            .update(cx, |s, _| {
 595                                s.credentials = Some(refreshed.clone());
 596                                s.refresh_task = None;
 597                            })
 598                            .map_err(|e| Arc::new(e))?;
 599
 600                        Ok(refreshed)
 601                    }
 602                    .await;
 603
 604                    // Clear refresh_task on failure too.
 605                    if persist_result.is_err() {
 606                        let _ = state_clone.update(cx, |s, _| {
 607                            s.refresh_task = None;
 608                        });
 609                    }
 610
 611                    persist_result
 612                }
 613                Err(RefreshError::Fatal(e)) => {
 614                    log::error!("ChatGPT subscription token refresh failed fatally: {e:?}");
 615                    let _ = state_clone.update(cx, |s, cx| {
 616                        s.refresh_task = None;
 617                        s.credentials = None;
 618                        s.last_auth_error =
 619                            Some("Your session has expired. Please sign in again.".into());
 620                        cx.notify();
 621                    });
 622                    // Also clear the keychain so stale credentials aren't loaded next time.
 623                    if let Ok(credentials_provider) =
 624                        state_clone.read_with(&*cx, |s, _| s.credentials_provider.clone())
 625                    {
 626                        credentials_provider
 627                            .delete_credentials(CREDENTIALS_KEY, &*cx)
 628                            .await
 629                            .log_err();
 630                    }
 631                    Err(Arc::new(e))
 632                }
 633                Err(RefreshError::Transient(e)) => {
 634                    log::warn!("ChatGPT subscription token refresh failed transiently: {e:?}");
 635                    let _ = state_clone.update(cx, |s, _| {
 636                        s.refresh_task = None;
 637                    });
 638                    Err(Arc::new(e))
 639                }
 640            }
 641        })
 642        .shared();
 643
 644    // Store the shared task so concurrent callers can join on it.
 645    state
 646        .update(cx, |s, _| {
 647            s.refresh_task = Some(shared_task.clone());
 648        })
 649        .map_err(LanguageModelCompletionError::Other)?;
 650
 651    shared_task
 652        .await
 653        .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")))
 654}
 655
 656#[derive(Deserialize)]
 657struct TokenResponse {
 658    access_token: String,
 659    refresh_token: String,
 660    #[serde(default)]
 661    id_token: Option<String>,
 662    expires_in: u64,
 663    #[serde(default)]
 664    email: Option<String>,
 665}
 666
 667async fn do_oauth_flow(
 668    http_client: Arc<dyn HttpClient>,
 669    cx: &AsyncApp,
 670) -> Result<CodexCredentials> {
 671    // Start the callback server FIRST so the redirect URI is ready
 672    let (redirect_uri, callback_rx) = http_client::start_oauth_callback_server()
 673        .context("Failed to start OAuth callback server")?;
 674
 675    // PKCE verifier: 32 random bytes → base64url (no padding)
 676    let mut verifier_bytes = [0u8; 32];
 677    rand::rng().fill_bytes(&mut verifier_bytes);
 678    let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
 679
 680    // PKCE challenge: SHA-256(verifier) → base64url
 681    let mut hasher = Sha256::new();
 682    hasher.update(verifier.as_bytes());
 683    let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize().as_slice());
 684
 685    // CSRF state: 16 random bytes → hex string
 686    let mut state_bytes = [0u8; 16];
 687    rand::rng().fill_bytes(&mut state_bytes);
 688    let oauth_state: String = state_bytes.iter().map(|b| format!("{b:02x}")).collect();
 689
 690    let mut auth_url = url::Url::parse(OPENAI_AUTHORIZE_URL).expect("valid base URL");
 691    auth_url
 692        .query_pairs_mut()
 693        .append_pair("client_id", CLIENT_ID)
 694        .append_pair("redirect_uri", &redirect_uri)
 695        .append_pair("scope", "openid profile email offline_access")
 696        .append_pair("response_type", "code")
 697        .append_pair("code_challenge", &challenge)
 698        .append_pair("code_challenge_method", "S256")
 699        .append_pair("state", &oauth_state)
 700        .append_pair("codex_cli_simplified_flow", "true")
 701        .append_pair("originator", "zed");
 702
 703    // Open browser AFTER the listener is ready
 704    cx.update(|cx| cx.open_url(auth_url.as_str()));
 705
 706    // Await the callback
 707    let callback = callback_rx
 708        .await
 709        .map_err(|_| anyhow!("OAuth callback was cancelled"))?
 710        .context("OAuth callback failed")?;
 711
 712    // Validate CSRF state
 713    if callback.state != oauth_state {
 714        return Err(anyhow!("OAuth state mismatch"));
 715    }
 716
 717    let tokens = exchange_code(&http_client, &callback.code, &verifier, &redirect_uri)
 718        .await
 719        .context("Token exchange failed")?;
 720
 721    let jwt = tokens
 722        .id_token
 723        .as_deref()
 724        .unwrap_or(tokens.access_token.as_str());
 725    let claims = extract_jwt_claims(jwt);
 726
 727    Ok(CodexCredentials {
 728        access_token: tokens.access_token,
 729        refresh_token: tokens.refresh_token,
 730        expires_at_ms: now_ms() + tokens.expires_in * 1000,
 731        account_id: claims.account_id,
 732        email: claims.email.or(tokens.email),
 733    })
 734}
 735
 736async fn exchange_code(
 737    client: &Arc<dyn HttpClient>,
 738    code: &str,
 739    verifier: &str,
 740    redirect_uri: &str,
 741) -> Result<TokenResponse> {
 742    let body = form_urlencoded::Serializer::new(String::new())
 743        .append_pair("grant_type", "authorization_code")
 744        .append_pair("client_id", CLIENT_ID)
 745        .append_pair("code", code)
 746        .append_pair("redirect_uri", redirect_uri)
 747        .append_pair("code_verifier", verifier)
 748        .finish();
 749
 750    let request = HttpRequest::builder()
 751        .method(Method::POST)
 752        .uri(OPENAI_TOKEN_URL)
 753        .header("Content-Type", "application/x-www-form-urlencoded")
 754        .body(AsyncBody::from(body))?;
 755
 756    let mut response = client.send(request).await?;
 757    let mut body = String::new();
 758    smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body).await?;
 759
 760    if !response.status().is_success() {
 761        return Err(anyhow!(
 762            "Token exchange failed (HTTP {}): {body}",
 763            response.status()
 764        ));
 765    }
 766
 767    serde_json::from_str::<TokenResponse>(&body).context("Failed to parse token response")
 768}
 769
 770async fn refresh_token(
 771    client: &Arc<dyn HttpClient>,
 772    refresh_token: &str,
 773) -> Result<CodexCredentials, RefreshError> {
 774    let body = form_urlencoded::Serializer::new(String::new())
 775        .append_pair("grant_type", "refresh_token")
 776        .append_pair("client_id", CLIENT_ID)
 777        .append_pair("refresh_token", refresh_token)
 778        .finish();
 779
 780    let request = HttpRequest::builder()
 781        .method(Method::POST)
 782        .uri(OPENAI_TOKEN_URL)
 783        .header("Content-Type", "application/x-www-form-urlencoded")
 784        .body(AsyncBody::from(body))
 785        .map_err(|e| RefreshError::Transient(e.into()))?;
 786
 787    let mut response = client
 788        .send(request)
 789        .await
 790        .map_err(|e| RefreshError::Transient(e))?;
 791    let status = response.status();
 792    let mut body = String::new();
 793    smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
 794        .await
 795        .map_err(|e| RefreshError::Transient(e.into()))?;
 796
 797    if !status.is_success() {
 798        let err = anyhow!("Token refresh failed (HTTP {}): {body}", status);
 799        // 400/401/403 indicate a revoked or invalid refresh token.
 800        // 5xx and other errors are treated as transient.
 801        if status == http_client::StatusCode::BAD_REQUEST
 802            || status == http_client::StatusCode::UNAUTHORIZED
 803            || status == http_client::StatusCode::FORBIDDEN
 804        {
 805            return Err(RefreshError::Fatal(err));
 806        }
 807        return Err(RefreshError::Transient(err));
 808    }
 809
 810    let tokens: TokenResponse =
 811        serde_json::from_str(&body).map_err(|e| RefreshError::Transient(e.into()))?;
 812    let jwt = tokens
 813        .id_token
 814        .as_deref()
 815        .unwrap_or(tokens.access_token.as_str());
 816    let claims = extract_jwt_claims(jwt);
 817
 818    Ok(CodexCredentials {
 819        access_token: tokens.access_token,
 820        refresh_token: tokens.refresh_token,
 821        expires_at_ms: now_ms() + tokens.expires_in * 1000,
 822        account_id: claims.account_id,
 823        email: claims.email.or(tokens.email),
 824    })
 825}
 826
 827struct JwtClaims {
 828    account_id: Option<String>,
 829    email: Option<String>,
 830}
 831
 832/// Extract claims from a JWT payload (base64url middle segment).
 833/// Extracts `chatgpt_account_id` from three possible locations (matching Roo Code's
 834/// implementation) and the `email` claim.
 835fn extract_jwt_claims(jwt: &str) -> JwtClaims {
 836    let Some(payload_b64) = jwt.split('.').nth(1) else {
 837        return JwtClaims {
 838            account_id: None,
 839            email: None,
 840        };
 841    };
 842    let Ok(payload) = URL_SAFE_NO_PAD.decode(payload_b64) else {
 843        return JwtClaims {
 844            account_id: None,
 845            email: None,
 846        };
 847    };
 848    let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&payload) else {
 849        return JwtClaims {
 850            account_id: None,
 851            email: None,
 852        };
 853    };
 854
 855    let account_id = claims
 856        .get("chatgpt_account_id")
 857        .and_then(|v| v.as_str())
 858        .or_else(|| {
 859            claims
 860                .get("https://api.openai.com/auth")
 861                .and_then(|v| v.get("chatgpt_account_id"))
 862                .and_then(|v| v.as_str())
 863        })
 864        .or_else(|| {
 865            claims
 866                .get("organizations")
 867                .and_then(|v| v.as_array())
 868                .and_then(|arr| arr.first())
 869                .and_then(|org| org.get("id"))
 870                .and_then(|v| v.as_str())
 871        })
 872        .map(|s| s.to_owned());
 873
 874    let email = claims
 875        .get("email")
 876        .and_then(|v| v.as_str())
 877        .map(|s| s.to_owned());
 878
 879    JwtClaims { account_id, email }
 880}
 881
 882fn now_ms() -> u64 {
 883    SystemTime::now()
 884        .duration_since(UNIX_EPOCH)
 885        .map(|d| d.as_millis() as u64)
 886        .unwrap_or_else(|err| {
 887            log::error!("System clock is before UNIX epoch: {err}");
 888            0
 889        })
 890}
 891
 892fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut App) {
 893    if state.read(cx).is_signing_in() {
 894        return;
 895    }
 896
 897    let weak_state = state.downgrade();
 898    let http_client = http_client.clone();
 899
 900    let task = cx.spawn(async move |cx| {
 901        match do_oauth_flow(http_client, &*cx).await {
 902            Ok(creds) => {
 903                let persist_result = async {
 904                    let credentials_provider =
 905                        weak_state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
 906                    let json = serde_json::to_vec(&creds)?;
 907                    credentials_provider
 908                        .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
 909                        .await?;
 910                    anyhow::Ok(())
 911                }
 912                .await;
 913
 914                match persist_result {
 915                    Ok(()) => {
 916                        weak_state
 917                            .update(cx, |s, cx| {
 918                                s.credentials = Some(creds);
 919                                s.sign_in_task = None;
 920                                s.last_auth_error = None;
 921                                cx.notify();
 922                            })
 923                            .log_err();
 924                    }
 925                    Err(err) => {
 926                        log::error!(
 927                            "ChatGPT subscription sign-in failed to persist credentials: {err:?}"
 928                        );
 929                        weak_state
 930                            .update(cx, |s, cx| {
 931                                s.sign_in_task = None;
 932                                s.last_auth_error =
 933                                    Some("Failed to save credentials. Please try again.".into());
 934                                cx.notify();
 935                            })
 936                            .log_err();
 937                    }
 938                }
 939            }
 940            Err(err) => {
 941                log::error!("ChatGPT subscription sign-in failed: {err:?}");
 942                weak_state
 943                    .update(cx, |s, cx| {
 944                        s.sign_in_task = None;
 945                        s.last_auth_error = Some("Sign-in failed. Please try again.".into());
 946                        cx.notify();
 947                    })
 948                    .log_err();
 949            }
 950        }
 951        anyhow::Ok(())
 952    });
 953
 954    state.update(cx, |s, cx| {
 955        s.last_auth_error = None;
 956        s.sign_in_task = Some(task);
 957        cx.notify();
 958    });
 959}
 960
 961fn do_sign_out(state: &gpui::WeakEntity<State>, cx: &mut App) -> Task<Result<()>> {
 962    let weak_state = state.clone();
 963    // Clear credentials and cancel in-flight work immediately so the UI
 964    // reflects the sign-out right away.
 965    weak_state
 966        .update(cx, |s, cx| {
 967            s.auth_generation += 1;
 968            s.credentials = None;
 969            s.sign_in_task = None;
 970            s.refresh_task = None;
 971            s.last_auth_error = None;
 972            cx.notify();
 973        })
 974        .log_err();
 975
 976    cx.spawn(async move |cx| {
 977        let credentials_provider =
 978            weak_state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
 979        credentials_provider
 980            .delete_credentials(CREDENTIALS_KEY, &*cx)
 981            .await
 982            .context("Failed to delete ChatGPT subscription credentials from keychain")?;
 983        anyhow::Ok(())
 984    })
 985}
 986
 987struct ConfigurationView {
 988    state: Entity<State>,
 989    http_client: Arc<dyn HttpClient>,
 990}
 991
 992impl Render for ConfigurationView {
 993    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 994        let state = self.state.read(cx);
 995
 996        if state.is_authenticated() {
 997            let label = state
 998                .email()
 999                .map(|e| format!("Signed in as {e}"))
1000                .unwrap_or_else(|| "Signed in".to_string());
1001
1002            let weak_state = self.state.downgrade();
1003            return v_flex()
1004                .child(
1005                    ConfiguredApiCard::new(SharedString::from(label))
1006                        .button_label("Sign Out")
1007                        .on_click(cx.listener(move |_this, _, _window, cx| {
1008                            do_sign_out(&weak_state, cx).detach_and_log_err(cx);
1009                        })),
1010                )
1011                .into_any_element();
1012        }
1013
1014        if state.is_signing_in() {
1015            return v_flex()
1016                .child(Label::new("Signing in…").color(Color::Muted))
1017                .into_any_element();
1018        }
1019
1020        let last_auth_error = state.last_auth_error.clone();
1021        let provider_state = self.state.clone();
1022        let http_client = self.http_client.clone();
1023
1024        v_flex()
1025            .gap_2()
1026            .when_some(last_auth_error, |this, error| {
1027                this.child(Label::new(error).color(Color::Error))
1028            })
1029            .child(Label::new(
1030                "Sign in with your ChatGPT Plus or Pro subscription to use OpenAI models in Zed's agent.",
1031            ))
1032            .child(
1033                Button::new("sign-in", "Sign in with ChatGPT")
1034                    .on_click(move |_, _window, cx| {
1035                        do_sign_in(&provider_state, &http_client, cx);
1036                    }),
1037            )
1038            .into_any_element()
1039    }
1040}
1041
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use http_client::FakeHttpClient;
    use parking_lot::Mutex;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// In-memory `CredentialsProvider` holding at most one `(username, password)` entry.
    ///
    /// The `url` argument is ignored, so every key shares the same slot. Writes and
    /// deletes take effect immediately, before the returned future is polled.
    struct FakeCredentialsProvider {
        storage: Mutex<Option<(String, Vec<u8>)>>,
    }

    impl FakeCredentialsProvider {
        fn new() -> Self {
            Self {
                storage: Mutex::new(None),
            }
        }
    }

    impl CredentialsProvider for FakeCredentialsProvider {
        fn read_credentials<'a>(
            &'a self,
            _url: &'a str,
            _cx: &'a AsyncApp,
        ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
            Box::pin(async { Ok(self.storage.lock().clone()) })
        }

        fn write_credentials<'a>(
            &'a self,
            _url: &'a str,
            username: &'a str,
            password: &'a [u8],
            _cx: &'a AsyncApp,
        ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
            self.storage
                .lock()
                .replace((username.to_string(), password.to_vec()));
            Box::pin(async { Ok(()) })
        }

        fn delete_credentials<'a>(
            &'a self,
            _url: &'a str,
            _cx: &'a AsyncApp,
        ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
            *self.storage.lock() = None;
            Box::pin(async { Ok(()) })
        }
    }

    /// Build a `State` with the given in-memory credentials and credential store.
    ///
    /// There are no in-flight tasks, the auth generation is 0, and no auth error is set.
    /// Keeping this in one place means a new `State` field only needs to be added here,
    /// not in every test.
    fn new_state(
        credentials: Option<CodexCredentials>,
        credentials_provider: Arc<FakeCredentialsProvider>,
    ) -> State {
        State {
            credentials,
            sign_in_task: None,
            refresh_task: None,
            load_task: None,
            credentials_provider,
            auth_generation: 0,
            last_auth_error: None,
        }
    }

    /// Fake HTTP client that answers every request with a successful token response
    /// (see `fake_token_response`) and increments `request_count` once per request.
    fn counting_token_client(request_count: Arc<AtomicUsize>) -> Arc<dyn HttpClient> {
        FakeHttpClient::create(move |_request| {
            let request_count = request_count.clone();
            async move {
                request_count.fetch_add(1, Ordering::SeqCst);
                let body = fake_token_response();
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(http_client::AsyncBody::from(body))?)
            }
        })
    }

    /// Fake HTTP client that answers every request with the given status and body.
    fn status_client(status: u16, body: &'static str) -> Arc<dyn HttpClient> {
        FakeHttpClient::create(move |_request| async move {
            Ok(http_client::Response::builder()
                .status(status)
                .body(http_client::AsyncBody::from(body))?)
        })
    }

    fn make_expired_credentials() -> CodexCredentials {
        CodexCredentials {
            access_token: "old_access".to_string(),
            refresh_token: "old_refresh".to_string(),
            expires_at_ms: 0,
            account_id: None,
            email: None,
        }
    }

    fn make_fresh_credentials() -> CodexCredentials {
        CodexCredentials {
            access_token: "fresh_access".to_string(),
            refresh_token: "fresh_refresh".to_string(),
            expires_at_ms: now_ms() + 3_600_000,
            account_id: None,
            email: None,
        }
    }

    fn fake_token_response() -> String {
        serde_json::json!({
            "access_token": "fresh_access",
            "refresh_token": "fresh_refresh",
            "expires_in": 3600
        })
        .to_string()
    }

    #[gpui::test]
    async fn test_concurrent_refresh_deduplicates(cx: &mut TestAppContext) {
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let http = counting_token_client(refresh_count.clone());

        let state = cx.new(|_cx| {
            new_state(
                Some(make_expired_credentials()),
                Arc::new(FakeCredentialsProvider::new()),
            )
        });
        let weak_state = cx.read(|_cx| state.downgrade());

        // Spawn two concurrent refresh attempts.
        let weak1 = weak_state.clone();
        let http1 = http.clone();
        let task1 =
            cx.spawn(async move |mut cx| get_fresh_credentials(&weak1, &http1, &mut cx).await);

        let weak2 = weak_state.clone();
        let http2 = http.clone();
        let task2 =
            cx.spawn(async move |mut cx| get_fresh_credentials(&weak2, &http2, &mut cx).await);

        // Drive both to completion.
        cx.run_until_parked();
        let result1 = task1.await;
        let result2 = task2.await;

        assert!(result1.is_ok(), "first refresh should succeed");
        assert!(result2.is_ok(), "second refresh should succeed");
        assert_eq!(result1.unwrap().access_token, "fresh_access");
        assert_eq!(result2.unwrap().access_token, "fresh_access");
        assert_eq!(
            refresh_count.load(Ordering::SeqCst),
            1,
            "refresh_token should only be called once despite two concurrent callers"
        );
    }

    #[gpui::test]
    async fn test_fresh_credentials_skip_refresh(cx: &mut TestAppContext) {
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let http = counting_token_client(refresh_count.clone());

        let state = cx.new(|_cx| {
            new_state(
                Some(make_fresh_credentials()),
                Arc::new(FakeCredentialsProvider::new()),
            )
        });
        let weak_state = cx.read(|_cx| state.downgrade());

        let weak = weak_state.clone();
        let http_clone = http.clone();
        let result = cx
            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().access_token, "fresh_access");
        assert_eq!(
            refresh_count.load(Ordering::SeqCst),
            0,
            "no refresh should happen when credentials are fresh"
        );
    }

    #[gpui::test]
    async fn test_no_credentials_returns_no_api_key(cx: &mut TestAppContext) {
        let http_client = FakeHttpClient::create(|_| async {
            Ok(http_client::Response::builder()
                .status(200)
                .body(http_client::AsyncBody::default())?)
        });

        let state = cx.new(|_cx| new_state(None, Arc::new(FakeCredentialsProvider::new())));
        let weak_state = cx.read(|_cx| state.downgrade());
        let http: Arc<dyn HttpClient> = http_client;

        let weak = weak_state.clone();
        let http_clone = http.clone();
        let result = cx
            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
            .await;

        assert!(matches!(
            result,
            Err(LanguageModelCompletionError::NoApiKey { .. })
        ));
    }

    #[gpui::test]
    async fn test_fatal_refresh_clears_auth_state(cx: &mut TestAppContext) {
        let http = status_client(401, r#"{"error":"invalid_grant"}"#);

        let creds_provider = Arc::new(FakeCredentialsProvider::new());
        let state =
            cx.new(|_cx| new_state(Some(make_expired_credentials()), creds_provider.clone()));
        let weak_state = cx.read(|_cx| state.downgrade());

        let weak = weak_state.clone();
        let http_clone = http.clone();
        let result = cx
            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
            .await;

        cx.run_until_parked();

        assert!(result.is_err(), "fatal refresh should return an error");
        cx.read(|cx| {
            let s = state.read(cx);
            assert!(
                s.credentials.is_none(),
                "credentials should be cleared on fatal refresh failure"
            );
            assert!(
                s.last_auth_error.is_some(),
                "last_auth_error should be set on fatal refresh failure"
            );
        });
    }

    #[gpui::test]
    async fn test_transient_refresh_keeps_credentials(cx: &mut TestAppContext) {
        let http = status_client(500, "Internal Server Error");

        let state = cx.new(|_cx| {
            new_state(
                Some(make_expired_credentials()),
                Arc::new(FakeCredentialsProvider::new()),
            )
        });
        let weak_state = cx.read(|_cx| state.downgrade());

        let weak = weak_state.clone();
        let http_clone = http.clone();
        let result = cx
            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
            .await;

        cx.run_until_parked();

        assert!(result.is_err(), "transient refresh should return an error");
        cx.read(|cx| {
            let s = state.read(cx);
            assert!(
                s.credentials.is_some(),
                "credentials should be kept on transient refresh failure"
            );
            assert!(
                s.last_auth_error.is_none(),
                "last_auth_error should not be set on transient refresh failure"
            );
        });
    }

    #[gpui::test]
    async fn test_sign_out_during_refresh_discards_result(cx: &mut TestAppContext) {
        let (gate_tx, gate_rx) = futures::channel::oneshot::channel::<()>();
        let gate_rx = Arc::new(Mutex::new(Some(gate_rx)));
        let gate_rx_clone = gate_rx.clone();

        let http_client = FakeHttpClient::create(move |_request| {
            let gate_rx = gate_rx_clone.clone();
            async move {
                // Wait until the gate is opened, simulating a slow network.
                let rx = gate_rx.lock().take();
                if let Some(rx) = rx {
                    let _ = rx.await;
                }
                let body = fake_token_response();
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(http_client::AsyncBody::from(body))?)
            }
        });

        let creds_provider = Arc::new(FakeCredentialsProvider::new());
        let state =
            cx.new(|_cx| new_state(Some(make_expired_credentials()), creds_provider.clone()));
        let weak_state = cx.read(|_cx| state.downgrade());
        let http: Arc<dyn HttpClient> = http_client;

        // Start a refresh
        let weak = weak_state.clone();
        let http_clone = http.clone();
        let refresh_task =
            cx.spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await);

        cx.run_until_parked();

        // Sign out while the refresh is in-flight
        cx.update(|cx| {
            do_sign_out(&weak_state, cx).detach();
        });
        cx.run_until_parked();

        // Now let the refresh respond by opening the gate
        let _ = gate_tx.send(());
        cx.run_until_parked();

        let result = refresh_task.await;
        assert!(result.is_err(), "refresh should fail after sign-out");

        cx.read(|cx| {
            let s = state.read(cx);
            assert!(
                s.credentials.is_none(),
                "sign-out should have cleared credentials"
            );
        });
    }

    #[gpui::test]
    async fn test_sign_out_completes_fully(cx: &mut TestAppContext) {
        let creds_provider = Arc::new(FakeCredentialsProvider::new());
        // Pre-populate the credential store
        creds_provider
            .storage
            .lock()
            .replace(("Bearer".to_string(), b"some-creds".to_vec()));

        let state =
            cx.new(|_cx| new_state(Some(make_fresh_credentials()), creds_provider.clone()));
        let weak_state = cx.read(|_cx| state.downgrade());
        let sign_out_task = cx.update(|cx| do_sign_out(&weak_state, cx));

        cx.run_until_parked();
        sign_out_task.await.expect("sign-out should succeed");

        assert!(
            creds_provider.storage.lock().is_none(),
            "credential store should be empty after sign-out"
        );
        cx.read(|cx| {
            assert!(
                !state.read(cx).is_authenticated(),
                "state should show not authenticated"
            );
        });
    }

    #[gpui::test]
    async fn test_authenticate_awaits_initial_load(cx: &mut TestAppContext) {
        let creds = make_fresh_credentials();
        let creds_json = serde_json::to_vec(&creds).unwrap();
        let creds_provider = Arc::new(FakeCredentialsProvider::new());
        creds_provider
            .storage
            .lock()
            .replace(("Bearer".to_string(), creds_json));

        let http_client = FakeHttpClient::create(|_| async {
            Ok(http_client::Response::builder()
                .status(200)
                .body(http_client::AsyncBody::default())?)
        });

        let provider =
            cx.update(|cx| OpenAiSubscribedProvider::new(http_client, creds_provider, cx));

        // Before load completes, authenticate should still await the load.
        let auth_task = cx.update(|cx| provider.authenticate(cx));

        // Drive the load to completion.
        cx.run_until_parked();

        let result = auth_task.await;
        assert!(
            result.is_ok(),
            "authenticate should succeed after load completes with valid credentials"
        );
    }
}