1use anyhow::{Context as _, Result, anyhow};
2use base64::Engine as _;
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use credentials_provider::CredentialsProvider;
5use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared};
6use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
7use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
8use language_model::{
9 AuthenticateError, IconOrSvg, LanguageModel, LanguageModelCompletionError,
10 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
11 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
12 LanguageModelRequest, LanguageModelToolChoice, RateLimiter,
13};
14use open_ai::{ReasoningEffort, responses::stream_response};
15use rand::RngCore as _;
16use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18use std::sync::Arc;
19use std::time::{SystemTime, UNIX_EPOCH};
20use ui::{ConfiguredApiCard, prelude::*};
21use url::form_urlencoded;
22use util::ResultExt as _;
23
24use crate::provider::open_ai::{OpenAiResponseEventMapper, into_open_ai_response};
25
/// Provider identifier returned by `id()` / `provider_id()`.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openai-subscribed");
/// User-visible provider name; also passed to `stream_response` and used in
/// `NoApiKey` errors.
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("ChatGPT Subscription");

/// Base URL of the ChatGPT Codex backend that completion requests are sent to.
const CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex";
/// OAuth token endpoint, used for both the authorization-code exchange and refreshes.
const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
/// OAuth authorization endpoint opened in the user's browser during sign-in.
const OPENAI_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
/// OAuth client ID sent with the authorize, code-exchange, and refresh requests.
const CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";

/// Key under which the JSON-serialized `CodexCredentials` are read, written, and
/// deleted via the credentials provider. (Currently the same string as
/// `CODEX_BASE_URL`.)
const CREDENTIALS_KEY: &str = "https://chatgpt.com/backend-api/codex";
/// Credentials are treated as expired this long (5 minutes, in ms) before their real
/// expiry, so the access token is refreshed early.
const TOKEN_REFRESH_BUFFER_MS: u64 = 5 * 60 * 1000;
37
/// OAuth credentials for the ChatGPT subscription, persisted as JSON via the
/// credentials provider under `CREDENTIALS_KEY`. Field names are part of the stored
/// format.
///
/// NOTE(review): the derived `Debug` prints the raw tokens; avoid logging this type.
#[derive(Serialize, Deserialize, Clone, Debug)]
struct CodexCredentials {
    /// Bearer token sent with Codex backend requests.
    access_token: String,
    /// Exchanged at `OPENAI_TOKEN_URL` for new credentials once the access token expires.
    refresh_token: String,
    /// Access-token expiry in milliseconds since the Unix epoch (same clock as `now_ms`).
    expires_at_ms: u64,
    /// ChatGPT account ID read from the JWT claims; sent as `ChatGPT-Account-Id` when
    /// non-empty.
    account_id: Option<String>,
    /// Account email, shown in the configuration view when known.
    email: Option<String>,
}
46
47impl CodexCredentials {
48 fn is_expired(&self) -> bool {
49 let now = now_ms();
50 now + TOKEN_REFRESH_BUFFER_MS >= self.expires_at_ms
51 }
52}
53
/// Authentication state shared by the provider, every model it creates, and the
/// configuration view.
pub struct State {
    /// Currently loaded credentials; `None` when signed out.
    credentials: Option<CodexCredentials>,
    /// In-flight OAuth sign-in, if any. Cleared when it finishes or on sign-out.
    sign_in_task: Option<Task<Result<()>>>,
    /// In-flight token refresh, shared so concurrent requests await a single refresh.
    refresh_task: Option<Shared<Task<Result<CodexCredentials, Arc<anyhow::Error>>>>>,
    /// In-flight initial credentials load; `authenticate` waits on it.
    load_task: Option<Shared<Task<Result<(), Arc<anyhow::Error>>>>>,
    /// Storage backend used to read, write, and delete the serialized credentials.
    credentials_provider: Arc<dyn CredentialsProvider>,
    /// Incremented on every sign-out so in-flight work can detect that its result
    /// belongs to a session that no longer exists.
    auth_generation: u64,
    /// Message shown in the configuration view after a failed sign-in or expired session.
    last_auth_error: Option<SharedString>,
}
63
64#[derive(Debug)]
65enum RefreshError {
66 Fatal(anyhow::Error),
67 Transient(anyhow::Error),
68}
69
70impl std::fmt::Display for RefreshError {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 match self {
73 RefreshError::Fatal(e) => write!(f, "{e}"),
74 RefreshError::Transient(e) => write!(f, "{e}"),
75 }
76 }
77}
78
79impl State {
80 fn is_authenticated(&self) -> bool {
81 self.credentials.is_some()
82 }
83
84 fn email(&self) -> Option<&str> {
85 self.credentials.as_ref().and_then(|c| c.email.as_deref())
86 }
87
88 fn is_signing_in(&self) -> bool {
89 self.sign_in_task.is_some()
90 }
91}
92
/// Language model provider that uses a ChatGPT Plus/Pro subscription, authenticated
/// through OAuth rather than an API key.
pub struct OpenAiSubscribedProvider {
    /// Used for OAuth token requests and for streaming completions.
    http_client: Arc<dyn HttpClient>,
    /// Auth state shared with the models and the configuration view.
    state: Entity<State>,
}
97
98impl OpenAiSubscribedProvider {
99 pub fn new(
100 http_client: Arc<dyn HttpClient>,
101 credentials_provider: Arc<dyn CredentialsProvider>,
102 cx: &mut App,
103 ) -> Self {
104 let state = cx.new(|_cx| State {
105 credentials: None,
106 sign_in_task: None,
107 refresh_task: None,
108 load_task: None,
109 credentials_provider,
110 auth_generation: 0,
111 last_auth_error: None,
112 });
113
114 let provider = Self { http_client, state };
115
116 provider.load_credentials(cx);
117
118 provider
119 }
120
121 fn load_credentials(&self, cx: &mut App) {
122 let state = self.state.downgrade();
123 let load_task = cx
124 .spawn(async move |cx| {
125 let credentials_provider =
126 state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
127 let result = credentials_provider
128 .read_credentials(CREDENTIALS_KEY, &*cx)
129 .await;
130 state.update(cx, |s, cx| {
131 if let Ok(Some((_, bytes))) = result {
132 match serde_json::from_slice::<CodexCredentials>(&bytes) {
133 Ok(creds) => s.credentials = Some(creds),
134 Err(err) => {
135 log::warn!(
136 "Failed to deserialize ChatGPT subscription credentials: {err}"
137 );
138 }
139 }
140 }
141 s.load_task = None;
142 cx.notify();
143 })?;
144 Ok::<(), Arc<anyhow::Error>>(())
145 })
146 .shared();
147
148 self.state.update(cx, |s, _| {
149 s.load_task = Some(load_task);
150 });
151 }
152
153 fn sign_out(&self, cx: &mut App) -> Task<Result<()>> {
154 do_sign_out(&self.state.downgrade(), cx)
155 }
156
157 fn create_language_model(&self, model: ChatGptModel) -> Arc<dyn LanguageModel> {
158 Arc::new(OpenAiSubscribedLanguageModel {
159 id: LanguageModelId::from(model.id().to_string()),
160 model,
161 state: self.state.clone(),
162 http_client: self.http_client.clone(),
163 request_limiter: RateLimiter::new(4),
164 })
165 }
166}
167
168impl LanguageModelProviderState for OpenAiSubscribedProvider {
169 type ObservableEntity = State;
170
171 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
172 Some(self.state.clone())
173 }
174}
175
176impl LanguageModelProvider for OpenAiSubscribedProvider {
177 fn id(&self) -> LanguageModelProviderId {
178 PROVIDER_ID
179 }
180
181 fn name(&self) -> LanguageModelProviderName {
182 PROVIDER_NAME
183 }
184
185 fn icon(&self) -> IconOrSvg {
186 IconOrSvg::Icon(IconName::AiOpenAi)
187 }
188
189 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
190 Some(self.create_language_model(ChatGptModel::Gpt54))
191 }
192
193 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
194 Some(self.create_language_model(ChatGptModel::Gpt54Mini))
195 }
196
197 fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
198 ChatGptModel::all()
199 .into_iter()
200 .map(|m| self.create_language_model(m))
201 .collect()
202 }
203
204 fn is_authenticated(&self, cx: &App) -> bool {
205 self.state.read(cx).is_authenticated()
206 }
207
208 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
209 if self.is_authenticated(cx) {
210 return Task::ready(Ok(()));
211 }
212 let load_task = self.state.read(cx).load_task.clone();
213 if let Some(load_task) = load_task {
214 let weak_state = self.state.downgrade();
215 cx.spawn(async move |cx| {
216 let _ = load_task.await;
217 let is_auth = weak_state
218 .read_with(&*cx, |s, _| s.is_authenticated())
219 .unwrap_or(false);
220 if is_auth {
221 Ok(())
222 } else {
223 Err(anyhow!(
224 "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
225 )
226 .into())
227 }
228 })
229 } else {
230 Task::ready(Err(anyhow!(
231 "Sign in with your ChatGPT Plus or Pro subscription to use this provider."
232 )
233 .into()))
234 }
235 }
236
237 fn configuration_view(
238 &self,
239 _target_agent: language_model::ConfigurationViewTargetAgent,
240 _window: &mut Window,
241 cx: &mut App,
242 ) -> AnyView {
243 let state = self.state.clone();
244 let http_client = self.http_client.clone();
245 cx.new(|_cx| ConfigurationView { state, http_client })
246 .into()
247 }
248
249 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
250 self.sign_out(cx)
251 }
252}
253
254//
255// The ChatGPT Subscription provider routes requests to chatgpt.com/backend-api/codex,
256// which only supports a subset of OpenAI models. This list is maintained separately
257// from the standard OpenAI API model list (open_ai::Model).
258
/// A model served by the ChatGPT Codex backend.
///
/// Variant names encode the model version (e.g. `Gpt51CodexMax`); see
/// `ChatGptModel::id` for the exact model IDs sent to the backend.
#[derive(Clone, Debug, PartialEq)]
enum ChatGptModel {
    Gpt5,
    Gpt5Codex,
    Gpt5CodexMini,
    Gpt51,
    Gpt51Codex,
    Gpt51CodexMax,
    Gpt51CodexMini,
    Gpt52,
    Gpt52Codex,
    Gpt53Codex,
    Gpt53CodexSpark,
    Gpt54,
    Gpt54Mini,
}
275
276impl ChatGptModel {
277 fn all() -> Vec<Self> {
278 vec![
279 Self::Gpt54,
280 Self::Gpt54Mini,
281 Self::Gpt53Codex,
282 Self::Gpt53CodexSpark,
283 Self::Gpt52Codex,
284 Self::Gpt52,
285 Self::Gpt51CodexMax,
286 Self::Gpt51Codex,
287 Self::Gpt51CodexMini,
288 Self::Gpt51,
289 Self::Gpt5Codex,
290 Self::Gpt5CodexMini,
291 Self::Gpt5,
292 ]
293 }
294
295 fn id(&self) -> &str {
296 match self {
297 Self::Gpt5 => "gpt-5",
298 Self::Gpt5Codex => "gpt-5-codex",
299 Self::Gpt5CodexMini => "gpt-5-codex-mini",
300 Self::Gpt51 => "gpt-5.1",
301 Self::Gpt51Codex => "gpt-5.1-codex",
302 Self::Gpt51CodexMax => "gpt-5.1-codex-max",
303 Self::Gpt51CodexMini => "gpt-5.1-codex-mini",
304 Self::Gpt52 => "gpt-5.2",
305 Self::Gpt52Codex => "gpt-5.2-codex",
306 Self::Gpt53Codex => "gpt-5.3-codex",
307 Self::Gpt53CodexSpark => "gpt-5.3-codex-spark",
308 Self::Gpt54 => "gpt-5.4",
309 Self::Gpt54Mini => "gpt-5.4-mini",
310 }
311 }
312
313 fn display_name(&self) -> &str {
314 match self {
315 Self::Gpt5 => "GPT-5",
316 Self::Gpt5Codex => "GPT-5 Codex",
317 Self::Gpt5CodexMini => "GPT-5 Codex Mini",
318 Self::Gpt51 => "GPT-5.1",
319 Self::Gpt51Codex => "GPT-5.1 Codex",
320 Self::Gpt51CodexMax => "GPT-5.1 Codex Max",
321 Self::Gpt51CodexMini => "GPT-5.1 Codex Mini",
322 Self::Gpt52 => "GPT-5.2",
323 Self::Gpt52Codex => "GPT-5.2 Codex",
324 Self::Gpt53Codex => "GPT-5.3 Codex",
325 Self::Gpt53CodexSpark => "GPT-5.3 Codex Spark",
326 Self::Gpt54 => "GPT-5.4",
327 Self::Gpt54Mini => "GPT-5.4 Mini",
328 }
329 }
330
331 fn max_token_count(&self) -> u64 {
332 match self {
333 Self::Gpt53CodexSpark => 128_000,
334 Self::Gpt54 | Self::Gpt54Mini => 1_050_000,
335 _ => 400_000,
336 }
337 }
338
339 fn max_output_tokens(&self) -> Option<u64> {
340 match self {
341 Self::Gpt53CodexSpark => Some(8_192),
342 _ => Some(128_000),
343 }
344 }
345
346 fn supports_images(&self) -> bool {
347 !matches!(self, Self::Gpt53CodexSpark)
348 }
349
350 fn reasoning_effort(&self) -> Option<ReasoningEffort> {
351 match self {
352 Self::Gpt54 | Self::Gpt54Mini => None,
353 _ => Some(ReasoningEffort::Medium),
354 }
355 }
356
357 fn supports_parallel_tool_calls(&self) -> bool {
358 match self {
359 Self::Gpt54 | Self::Gpt54Mini => true,
360 _ => false,
361 }
362 }
363
364 fn supports_prompt_cache_key(&self) -> bool {
365 true
366 }
367}
368
/// A single ChatGPT-subscription model; shares the provider's auth state.
struct OpenAiSubscribedLanguageModel {
    /// ID reported to Zed; built from `ChatGptModel::id`.
    id: LanguageModelId,
    /// Which backend model this instance talks to.
    model: ChatGptModel,
    /// Provider auth state, used to obtain (and refresh) credentials per request.
    state: Entity<State>,
    /// Client used to stream responses from the Codex backend.
    http_client: Arc<dyn HttpClient>,
    /// Throttles this instance's requests (created with `RateLimiter::new(4)`).
    request_limiter: RateLimiter,
}
376
377impl LanguageModel for OpenAiSubscribedLanguageModel {
378 fn id(&self) -> LanguageModelId {
379 self.id.clone()
380 }
381
382 fn name(&self) -> LanguageModelName {
383 LanguageModelName::from(self.model.display_name().to_string())
384 }
385
386 fn provider_id(&self) -> LanguageModelProviderId {
387 PROVIDER_ID
388 }
389
390 fn provider_name(&self) -> LanguageModelProviderName {
391 PROVIDER_NAME
392 }
393
394 fn supports_tools(&self) -> bool {
395 true
396 }
397
398 fn supports_images(&self) -> bool {
399 self.model.supports_images()
400 }
401
402 fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
403 true
404 }
405
406 fn supports_streaming_tools(&self) -> bool {
407 true
408 }
409
410 fn supports_thinking(&self) -> bool {
411 self.model.reasoning_effort().is_some()
412 }
413
414 fn telemetry_id(&self) -> String {
415 format!("openai-subscribed/{}", self.model.id())
416 }
417
418 fn max_token_count(&self) -> u64 {
419 self.model.max_token_count()
420 }
421
422 fn max_output_tokens(&self) -> Option<u64> {
423 self.model.max_output_tokens()
424 }
425
426 fn count_tokens(
427 &self,
428 request: LanguageModelRequest,
429 cx: &App,
430 ) -> BoxFuture<'static, Result<u64>> {
431 let max_token_count = self.model.max_token_count();
432 cx.background_spawn(async move {
433 let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
434 let model = if max_token_count >= 100_000 {
435 "gpt-4o"
436 } else {
437 "gpt-4"
438 };
439 tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
440 })
441 .boxed()
442 }
443
444 fn stream_completion(
445 &self,
446 request: LanguageModelRequest,
447 cx: &AsyncApp,
448 ) -> BoxFuture<
449 'static,
450 Result<
451 futures::stream::BoxStream<
452 'static,
453 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
454 >,
455 LanguageModelCompletionError,
456 >,
457 > {
458 let mut responses_request = into_open_ai_response(
459 request,
460 self.model.id(),
461 self.model.supports_parallel_tool_calls(),
462 self.model.supports_prompt_cache_key(),
463 self.max_output_tokens(),
464 self.model.reasoning_effort(),
465 );
466 responses_request.store = Some(false);
467
468 // The Codex backend requires system messages to be in the top-level
469 // `instructions` field rather than as input items.
470 let mut instructions = Vec::new();
471 responses_request.input.retain(|item| {
472 if let open_ai::responses::ResponseInputItem::Message(msg) = item {
473 if msg.role == open_ai::Role::System {
474 for part in &msg.content {
475 if let open_ai::responses::ResponseInputContent::Text { text } = part {
476 instructions.push(text.clone());
477 }
478 }
479 return false;
480 }
481 }
482 true
483 });
484 if !instructions.is_empty() {
485 responses_request.instructions = Some(instructions.join("\n\n"));
486 }
487
488 let state = self.state.downgrade();
489 let http_client = self.http_client.clone();
490 let request_limiter = self.request_limiter.clone();
491
492 let future = cx.spawn(async move |cx| {
493 let creds = get_fresh_credentials(&state, &http_client, cx).await?;
494
495 let mut extra_headers: Vec<(String, String)> = vec![
496 ("originator".into(), "zed".into()),
497 ("OpenAI-Beta".into(), "responses=experimental".into()),
498 ];
499 if let Some(ref id) = creds.account_id {
500 if !id.is_empty() {
501 extra_headers.push(("ChatGPT-Account-Id".into(), id.clone()));
502 }
503 }
504
505 let access_token = creds.access_token.clone();
506 request_limiter
507 .stream(async move {
508 stream_response(
509 http_client.as_ref(),
510 PROVIDER_NAME.0.as_str(),
511 CODEX_BASE_URL,
512 &access_token,
513 responses_request,
514 extra_headers,
515 )
516 .await
517 .map_err(LanguageModelCompletionError::from)
518 })
519 .await
520 });
521
522 async move {
523 let mapper = OpenAiResponseEventMapper::new();
524 Ok(mapper.map_stream(future.await?.boxed()).boxed())
525 }
526 .boxed()
527 }
528}
529
530async fn get_fresh_credentials(
531 state: &gpui::WeakEntity<State>,
532 http_client: &Arc<dyn HttpClient>,
533 cx: &mut AsyncApp,
534) -> Result<CodexCredentials, LanguageModelCompletionError> {
535 let (creds, existing_task) = state
536 .read_with(&*cx, |s, _| (s.credentials.clone(), s.refresh_task.clone()))
537 .map_err(LanguageModelCompletionError::Other)?;
538
539 let creds = creds.ok_or(LanguageModelCompletionError::NoApiKey {
540 provider: PROVIDER_NAME,
541 })?;
542
543 if !creds.is_expired() {
544 return Ok(creds);
545 }
546
547 // If another caller is already refreshing, await their result.
548 if let Some(shared_task) = existing_task {
549 return shared_task
550 .await
551 .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")));
552 }
553
554 // We are the first caller to notice expiry — spawn the refresh task.
555 let http_client_clone = http_client.clone();
556 let state_clone = state.clone();
557 let refresh_token_value = creds.refresh_token.clone();
558
559 // Capture the generation so we can detect sign-outs that happened during refresh.
560 let generation = state
561 .read_with(&*cx, |s, _| s.auth_generation)
562 .map_err(LanguageModelCompletionError::Other)?;
563
564 let shared_task = cx
565 .spawn(async move |cx| {
566 let result = refresh_token(&http_client_clone, &refresh_token_value).await;
567
568 match result {
569 Ok(refreshed) => {
570 let persist_result: Result<CodexCredentials, Arc<anyhow::Error>> = async {
571 // Check if auth_generation changed (sign-out during refresh).
572 let current_generation = state_clone
573 .read_with(&*cx, |s, _| s.auth_generation)
574 .map_err(|e| Arc::new(e))?;
575 if current_generation != generation {
576 return Err(Arc::new(anyhow!(
577 "Sign-out occurred during token refresh"
578 )));
579 }
580
581 let credentials_provider = state_clone
582 .read_with(&*cx, |s, _| s.credentials_provider.clone())
583 .map_err(|e| Arc::new(e))?;
584
585 let json =
586 serde_json::to_vec(&refreshed).map_err(|e| Arc::new(e.into()))?;
587
588 credentials_provider
589 .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
590 .await
591 .map_err(|e| Arc::new(e))?;
592
593 state_clone
594 .update(cx, |s, _| {
595 s.credentials = Some(refreshed.clone());
596 s.refresh_task = None;
597 })
598 .map_err(|e| Arc::new(e))?;
599
600 Ok(refreshed)
601 }
602 .await;
603
604 // Clear refresh_task on failure too.
605 if persist_result.is_err() {
606 let _ = state_clone.update(cx, |s, _| {
607 s.refresh_task = None;
608 });
609 }
610
611 persist_result
612 }
613 Err(RefreshError::Fatal(e)) => {
614 log::error!("ChatGPT subscription token refresh failed fatally: {e:?}");
615 let _ = state_clone.update(cx, |s, cx| {
616 s.refresh_task = None;
617 s.credentials = None;
618 s.last_auth_error =
619 Some("Your session has expired. Please sign in again.".into());
620 cx.notify();
621 });
622 // Also clear the keychain so stale credentials aren't loaded next time.
623 if let Ok(credentials_provider) =
624 state_clone.read_with(&*cx, |s, _| s.credentials_provider.clone())
625 {
626 credentials_provider
627 .delete_credentials(CREDENTIALS_KEY, &*cx)
628 .await
629 .log_err();
630 }
631 Err(Arc::new(e))
632 }
633 Err(RefreshError::Transient(e)) => {
634 log::warn!("ChatGPT subscription token refresh failed transiently: {e:?}");
635 let _ = state_clone.update(cx, |s, _| {
636 s.refresh_task = None;
637 });
638 Err(Arc::new(e))
639 }
640 }
641 })
642 .shared();
643
644 // Store the shared task so concurrent callers can join on it.
645 state
646 .update(cx, |s, _| {
647 s.refresh_task = Some(shared_task.clone());
648 })
649 .map_err(LanguageModelCompletionError::Other)?;
650
651 shared_task
652 .await
653 .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")))
654}
655
/// Successful JSON body returned by `OPENAI_TOKEN_URL`, for both the
/// authorization-code exchange and refresh requests.
#[derive(Deserialize)]
struct TokenResponse {
    access_token: String,
    /// NOTE(review): required here, so a refresh response that omits a rotated
    /// refresh token fails to parse (treated as a transient error) — confirm the
    /// endpoint always returns one.
    refresh_token: String,
    /// ID token (requested via the `openid` scope); preferred over `access_token`
    /// when reading JWT claims.
    #[serde(default)]
    id_token: Option<String>,
    /// Access-token lifetime in seconds; multiplied by 1000 to compute `expires_at_ms`.
    expires_in: u64,
    /// Fallback email, used when the JWT has no `email` claim.
    #[serde(default)]
    email: Option<String>,
}
666
667async fn do_oauth_flow(
668 http_client: Arc<dyn HttpClient>,
669 cx: &AsyncApp,
670) -> Result<CodexCredentials> {
671 // Start the callback server FIRST so the redirect URI is ready
672 let (redirect_uri, callback_rx) = http_client::start_oauth_callback_server()
673 .context("Failed to start OAuth callback server")?;
674
675 // PKCE verifier: 32 random bytes → base64url (no padding)
676 let mut verifier_bytes = [0u8; 32];
677 rand::rng().fill_bytes(&mut verifier_bytes);
678 let verifier = URL_SAFE_NO_PAD.encode(verifier_bytes);
679
680 // PKCE challenge: SHA-256(verifier) → base64url
681 let mut hasher = Sha256::new();
682 hasher.update(verifier.as_bytes());
683 let challenge = URL_SAFE_NO_PAD.encode(hasher.finalize().as_slice());
684
685 // CSRF state: 16 random bytes → hex string
686 let mut state_bytes = [0u8; 16];
687 rand::rng().fill_bytes(&mut state_bytes);
688 let oauth_state: String = state_bytes.iter().map(|b| format!("{b:02x}")).collect();
689
690 let mut auth_url = url::Url::parse(OPENAI_AUTHORIZE_URL).expect("valid base URL");
691 auth_url
692 .query_pairs_mut()
693 .append_pair("client_id", CLIENT_ID)
694 .append_pair("redirect_uri", &redirect_uri)
695 .append_pair("scope", "openid profile email offline_access")
696 .append_pair("response_type", "code")
697 .append_pair("code_challenge", &challenge)
698 .append_pair("code_challenge_method", "S256")
699 .append_pair("state", &oauth_state)
700 .append_pair("codex_cli_simplified_flow", "true")
701 .append_pair("originator", "zed");
702
703 // Open browser AFTER the listener is ready
704 cx.update(|cx| cx.open_url(auth_url.as_str()));
705
706 // Await the callback
707 let callback = callback_rx
708 .await
709 .map_err(|_| anyhow!("OAuth callback was cancelled"))?
710 .context("OAuth callback failed")?;
711
712 // Validate CSRF state
713 if callback.state != oauth_state {
714 return Err(anyhow!("OAuth state mismatch"));
715 }
716
717 let tokens = exchange_code(&http_client, &callback.code, &verifier, &redirect_uri)
718 .await
719 .context("Token exchange failed")?;
720
721 let jwt = tokens
722 .id_token
723 .as_deref()
724 .unwrap_or(tokens.access_token.as_str());
725 let claims = extract_jwt_claims(jwt);
726
727 Ok(CodexCredentials {
728 access_token: tokens.access_token,
729 refresh_token: tokens.refresh_token,
730 expires_at_ms: now_ms() + tokens.expires_in * 1000,
731 account_id: claims.account_id,
732 email: claims.email.or(tokens.email),
733 })
734}
735
736async fn exchange_code(
737 client: &Arc<dyn HttpClient>,
738 code: &str,
739 verifier: &str,
740 redirect_uri: &str,
741) -> Result<TokenResponse> {
742 let body = form_urlencoded::Serializer::new(String::new())
743 .append_pair("grant_type", "authorization_code")
744 .append_pair("client_id", CLIENT_ID)
745 .append_pair("code", code)
746 .append_pair("redirect_uri", redirect_uri)
747 .append_pair("code_verifier", verifier)
748 .finish();
749
750 let request = HttpRequest::builder()
751 .method(Method::POST)
752 .uri(OPENAI_TOKEN_URL)
753 .header("Content-Type", "application/x-www-form-urlencoded")
754 .body(AsyncBody::from(body))?;
755
756 let mut response = client.send(request).await?;
757 let mut body = String::new();
758 smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body).await?;
759
760 if !response.status().is_success() {
761 return Err(anyhow!(
762 "Token exchange failed (HTTP {}): {body}",
763 response.status()
764 ));
765 }
766
767 serde_json::from_str::<TokenResponse>(&body).context("Failed to parse token response")
768}
769
770async fn refresh_token(
771 client: &Arc<dyn HttpClient>,
772 refresh_token: &str,
773) -> Result<CodexCredentials, RefreshError> {
774 let body = form_urlencoded::Serializer::new(String::new())
775 .append_pair("grant_type", "refresh_token")
776 .append_pair("client_id", CLIENT_ID)
777 .append_pair("refresh_token", refresh_token)
778 .finish();
779
780 let request = HttpRequest::builder()
781 .method(Method::POST)
782 .uri(OPENAI_TOKEN_URL)
783 .header("Content-Type", "application/x-www-form-urlencoded")
784 .body(AsyncBody::from(body))
785 .map_err(|e| RefreshError::Transient(e.into()))?;
786
787 let mut response = client
788 .send(request)
789 .await
790 .map_err(|e| RefreshError::Transient(e))?;
791 let status = response.status();
792 let mut body = String::new();
793 smol::io::AsyncReadExt::read_to_string(response.body_mut(), &mut body)
794 .await
795 .map_err(|e| RefreshError::Transient(e.into()))?;
796
797 if !status.is_success() {
798 let err = anyhow!("Token refresh failed (HTTP {}): {body}", status);
799 // 400/401/403 indicate a revoked or invalid refresh token.
800 // 5xx and other errors are treated as transient.
801 if status == http_client::StatusCode::BAD_REQUEST
802 || status == http_client::StatusCode::UNAUTHORIZED
803 || status == http_client::StatusCode::FORBIDDEN
804 {
805 return Err(RefreshError::Fatal(err));
806 }
807 return Err(RefreshError::Transient(err));
808 }
809
810 let tokens: TokenResponse =
811 serde_json::from_str(&body).map_err(|e| RefreshError::Transient(e.into()))?;
812 let jwt = tokens
813 .id_token
814 .as_deref()
815 .unwrap_or(tokens.access_token.as_str());
816 let claims = extract_jwt_claims(jwt);
817
818 Ok(CodexCredentials {
819 access_token: tokens.access_token,
820 refresh_token: tokens.refresh_token,
821 expires_at_ms: now_ms() + tokens.expires_in * 1000,
822 account_id: claims.account_id,
823 email: claims.email.or(tokens.email),
824 })
825}
826
827struct JwtClaims {
828 account_id: Option<String>,
829 email: Option<String>,
830}
831
832/// Extract claims from a JWT payload (base64url middle segment).
833/// Extracts `chatgpt_account_id` from three possible locations (matching Roo Code's
834/// implementation) and the `email` claim.
835fn extract_jwt_claims(jwt: &str) -> JwtClaims {
836 let Some(payload_b64) = jwt.split('.').nth(1) else {
837 return JwtClaims {
838 account_id: None,
839 email: None,
840 };
841 };
842 let Ok(payload) = URL_SAFE_NO_PAD.decode(payload_b64) else {
843 return JwtClaims {
844 account_id: None,
845 email: None,
846 };
847 };
848 let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&payload) else {
849 return JwtClaims {
850 account_id: None,
851 email: None,
852 };
853 };
854
855 let account_id = claims
856 .get("chatgpt_account_id")
857 .and_then(|v| v.as_str())
858 .or_else(|| {
859 claims
860 .get("https://api.openai.com/auth")
861 .and_then(|v| v.get("chatgpt_account_id"))
862 .and_then(|v| v.as_str())
863 })
864 .or_else(|| {
865 claims
866 .get("organizations")
867 .and_then(|v| v.as_array())
868 .and_then(|arr| arr.first())
869 .and_then(|org| org.get("id"))
870 .and_then(|v| v.as_str())
871 })
872 .map(|s| s.to_owned());
873
874 let email = claims
875 .get("email")
876 .and_then(|v| v.as_str())
877 .map(|s| s.to_owned());
878
879 JwtClaims { account_id, email }
880}
881
882fn now_ms() -> u64 {
883 SystemTime::now()
884 .duration_since(UNIX_EPOCH)
885 .map(|d| d.as_millis() as u64)
886 .unwrap_or_else(|err| {
887 log::error!("System clock is before UNIX epoch: {err}");
888 0
889 })
890}
891
892fn do_sign_in(state: &Entity<State>, http_client: &Arc<dyn HttpClient>, cx: &mut App) {
893 if state.read(cx).is_signing_in() {
894 return;
895 }
896
897 let weak_state = state.downgrade();
898 let http_client = http_client.clone();
899
900 let task = cx.spawn(async move |cx| {
901 match do_oauth_flow(http_client, &*cx).await {
902 Ok(creds) => {
903 let persist_result = async {
904 let credentials_provider =
905 weak_state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
906 let json = serde_json::to_vec(&creds)?;
907 credentials_provider
908 .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
909 .await?;
910 anyhow::Ok(())
911 }
912 .await;
913
914 match persist_result {
915 Ok(()) => {
916 weak_state
917 .update(cx, |s, cx| {
918 s.credentials = Some(creds);
919 s.sign_in_task = None;
920 s.last_auth_error = None;
921 cx.notify();
922 })
923 .log_err();
924 }
925 Err(err) => {
926 log::error!(
927 "ChatGPT subscription sign-in failed to persist credentials: {err:?}"
928 );
929 weak_state
930 .update(cx, |s, cx| {
931 s.sign_in_task = None;
932 s.last_auth_error =
933 Some("Failed to save credentials. Please try again.".into());
934 cx.notify();
935 })
936 .log_err();
937 }
938 }
939 }
940 Err(err) => {
941 log::error!("ChatGPT subscription sign-in failed: {err:?}");
942 weak_state
943 .update(cx, |s, cx| {
944 s.sign_in_task = None;
945 s.last_auth_error = Some("Sign-in failed. Please try again.".into());
946 cx.notify();
947 })
948 .log_err();
949 }
950 }
951 anyhow::Ok(())
952 });
953
954 state.update(cx, |s, cx| {
955 s.last_auth_error = None;
956 s.sign_in_task = Some(task);
957 cx.notify();
958 });
959}
960
961fn do_sign_out(state: &gpui::WeakEntity<State>, cx: &mut App) -> Task<Result<()>> {
962 let weak_state = state.clone();
963 // Clear credentials and cancel in-flight work immediately so the UI
964 // reflects the sign-out right away.
965 weak_state
966 .update(cx, |s, cx| {
967 s.auth_generation += 1;
968 s.credentials = None;
969 s.sign_in_task = None;
970 s.refresh_task = None;
971 s.last_auth_error = None;
972 cx.notify();
973 })
974 .log_err();
975
976 cx.spawn(async move |cx| {
977 let credentials_provider =
978 weak_state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
979 credentials_provider
980 .delete_credentials(CREDENTIALS_KEY, &*cx)
981 .await
982 .context("Failed to delete ChatGPT subscription credentials from keychain")?;
983 anyhow::Ok(())
984 })
985}
986
/// Settings UI for signing in to, or out of, the ChatGPT subscription.
struct ConfigurationView {
    /// Auth state that is displayed and modified.
    state: Entity<State>,
    /// Passed to the OAuth sign-in flow started from the "Sign in" button.
    http_client: Arc<dyn HttpClient>,
}
991
992impl Render for ConfigurationView {
993 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
994 let state = self.state.read(cx);
995
996 if state.is_authenticated() {
997 let label = state
998 .email()
999 .map(|e| format!("Signed in as {e}"))
1000 .unwrap_or_else(|| "Signed in".to_string());
1001
1002 let weak_state = self.state.downgrade();
1003 return v_flex()
1004 .child(
1005 ConfiguredApiCard::new(SharedString::from(label))
1006 .button_label("Sign Out")
1007 .on_click(cx.listener(move |_this, _, _window, cx| {
1008 do_sign_out(&weak_state, cx).detach_and_log_err(cx);
1009 })),
1010 )
1011 .into_any_element();
1012 }
1013
1014 if state.is_signing_in() {
1015 return v_flex()
1016 .child(Label::new("Signing in…").color(Color::Muted))
1017 .into_any_element();
1018 }
1019
1020 let last_auth_error = state.last_auth_error.clone();
1021 let provider_state = self.state.clone();
1022 let http_client = self.http_client.clone();
1023
1024 v_flex()
1025 .gap_2()
1026 .when_some(last_auth_error, |this, error| {
1027 this.child(Label::new(error).color(Color::Error))
1028 })
1029 .child(Label::new(
1030 "Sign in with your ChatGPT Plus or Pro subscription to use OpenAI models in Zed's agent.",
1031 ))
1032 .child(
1033 Button::new("sign-in", "Sign in with ChatGPT")
1034 .on_click(move |_, _window, cx| {
1035 do_sign_in(&provider_state, &http_client, cx);
1036 }),
1037 )
1038 .into_any_element()
1039 }
1040}
1041
1042#[cfg(test)]
1043mod tests {
1044 use super::*;
1045 use gpui::TestAppContext;
1046 use http_client::FakeHttpClient;
1047 use parking_lot::Mutex;
1048 use std::future::Future;
1049 use std::pin::Pin;
1050 use std::sync::atomic::{AtomicUsize, Ordering};
1051
1052 struct FakeCredentialsProvider {
1053 storage: Mutex<Option<(String, Vec<u8>)>>,
1054 }
1055
1056 impl FakeCredentialsProvider {
1057 fn new() -> Self {
1058 Self {
1059 storage: Mutex::new(None),
1060 }
1061 }
1062 }
1063
1064 impl CredentialsProvider for FakeCredentialsProvider {
1065 fn read_credentials<'a>(
1066 &'a self,
1067 _url: &'a str,
1068 _cx: &'a AsyncApp,
1069 ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
1070 Box::pin(async { Ok(self.storage.lock().clone()) })
1071 }
1072
1073 fn write_credentials<'a>(
1074 &'a self,
1075 _url: &'a str,
1076 username: &'a str,
1077 password: &'a [u8],
1078 _cx: &'a AsyncApp,
1079 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1080 self.storage
1081 .lock()
1082 .replace((username.to_string(), password.to_vec()));
1083 Box::pin(async { Ok(()) })
1084 }
1085
1086 fn delete_credentials<'a>(
1087 &'a self,
1088 _url: &'a str,
1089 _cx: &'a AsyncApp,
1090 ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
1091 *self.storage.lock() = None;
1092 Box::pin(async { Ok(()) })
1093 }
1094 }
1095
1096 fn make_expired_credentials() -> CodexCredentials {
1097 CodexCredentials {
1098 access_token: "old_access".to_string(),
1099 refresh_token: "old_refresh".to_string(),
1100 expires_at_ms: 0,
1101 account_id: None,
1102 email: None,
1103 }
1104 }
1105
1106 fn make_fresh_credentials() -> CodexCredentials {
1107 CodexCredentials {
1108 access_token: "fresh_access".to_string(),
1109 refresh_token: "fresh_refresh".to_string(),
1110 expires_at_ms: now_ms() + 3_600_000,
1111 account_id: None,
1112 email: None,
1113 }
1114 }
1115
1116 fn fake_token_response() -> String {
1117 serde_json::json!({
1118 "access_token": "fresh_access",
1119 "refresh_token": "fresh_refresh",
1120 "expires_in": 3600
1121 })
1122 .to_string()
1123 }
1124
    /// Two callers that both observe expired credentials must share one refresh:
    /// both receive the refreshed token, and the token endpoint is hit exactly once.
    #[gpui::test]
    async fn test_concurrent_refresh_deduplicates(cx: &mut TestAppContext) {
        // Counts requests made to the fake token endpoint.
        let refresh_count = Arc::new(AtomicUsize::new(0));
        let refresh_count_clone = refresh_count.clone();

        // Every request succeeds with the "fresh" token response.
        let http_client = FakeHttpClient::create(move |_request| {
            let refresh_count = refresh_count_clone.clone();
            async move {
                refresh_count.fetch_add(1, Ordering::SeqCst);
                let body = fake_token_response();
                Ok(http_client::Response::builder()
                    .status(200)
                    .body(http_client::AsyncBody::from(body))?)
            }
        });

        // Start from expired credentials so both callers need a refresh.
        let state = cx.new(|_cx| State {
            credentials: Some(make_expired_credentials()),
            sign_in_task: None,
            refresh_task: None,
            load_task: None,
            credentials_provider: Arc::new(FakeCredentialsProvider::new()),
            auth_generation: 0,
            last_auth_error: None,
        });

        let weak_state = cx.read(|_cx| state.downgrade());
        let http: Arc<dyn HttpClient> = http_client;

        // Spawn two concurrent refresh attempts.
        let weak1 = weak_state.clone();
        let http1 = http.clone();
        let task1 =
            cx.spawn(async move |mut cx| get_fresh_credentials(&weak1, &http1, &mut cx).await);

        let weak2 = weak_state.clone();
        let http2 = http.clone();
        let task2 =
            cx.spawn(async move |mut cx| get_fresh_credentials(&weak2, &http2, &mut cx).await);

        // Drive both to completion. Whichever task runs first spawns the refresh and
        // stores it in `refresh_task`; the other finds it there and awaits it.
        cx.run_until_parked();
        let result1 = task1.await;
        let result2 = task2.await;

        assert!(result1.is_ok(), "first refresh should succeed");
        assert!(result2.is_ok(), "second refresh should succeed");
        assert_eq!(result1.unwrap().access_token, "fresh_access");
        assert_eq!(result2.unwrap().access_token, "fresh_access");
        assert_eq!(
            refresh_count.load(Ordering::SeqCst),
            1,
            "refresh_token should only be called once despite two concurrent callers"
        );
    }
1180
1181 #[gpui::test]
1182 async fn test_fresh_credentials_skip_refresh(cx: &mut TestAppContext) {
1183 let refresh_count = Arc::new(AtomicUsize::new(0));
1184 let refresh_count_clone = refresh_count.clone();
1185
1186 let http_client = FakeHttpClient::create(move |_request| {
1187 let refresh_count = refresh_count_clone.clone();
1188 async move {
1189 refresh_count.fetch_add(1, Ordering::SeqCst);
1190 let body = fake_token_response();
1191 Ok(http_client::Response::builder()
1192 .status(200)
1193 .body(http_client::AsyncBody::from(body))?)
1194 }
1195 });
1196
1197 let state = cx.new(|_cx| State {
1198 credentials: Some(make_fresh_credentials()),
1199 sign_in_task: None,
1200 refresh_task: None,
1201 load_task: None,
1202 credentials_provider: Arc::new(FakeCredentialsProvider::new()),
1203 auth_generation: 0,
1204 last_auth_error: None,
1205 });
1206
1207 let weak_state = cx.read(|_cx| state.downgrade());
1208 let http: Arc<dyn HttpClient> = http_client;
1209
1210 let weak = weak_state.clone();
1211 let http_clone = http.clone();
1212 let result = cx
1213 .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
1214 .await;
1215
1216 assert!(result.is_ok());
1217 assert_eq!(result.unwrap().access_token, "fresh_access");
1218 assert_eq!(
1219 refresh_count.load(Ordering::SeqCst),
1220 0,
1221 "no refresh should happen when credentials are fresh"
1222 );
1223 }
1224
1225 #[gpui::test]
1226 async fn test_no_credentials_returns_no_api_key(cx: &mut TestAppContext) {
1227 let http_client = FakeHttpClient::create(|_| async {
1228 Ok(http_client::Response::builder()
1229 .status(200)
1230 .body(http_client::AsyncBody::default())?)
1231 });
1232
1233 let state = cx.new(|_cx| State {
1234 credentials: None,
1235 sign_in_task: None,
1236 refresh_task: None,
1237 load_task: None,
1238 credentials_provider: Arc::new(FakeCredentialsProvider::new()),
1239 auth_generation: 0,
1240 last_auth_error: None,
1241 });
1242
1243 let weak_state = cx.read(|_cx| state.downgrade());
1244 let http: Arc<dyn HttpClient> = http_client;
1245
1246 let weak = weak_state.clone();
1247 let http_clone = http.clone();
1248 let result = cx
1249 .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
1250 .await;
1251
1252 assert!(matches!(
1253 result,
1254 Err(LanguageModelCompletionError::NoApiKey { .. })
1255 ));
1256 }
1257
1258 #[gpui::test]
1259 async fn test_fatal_refresh_clears_auth_state(cx: &mut TestAppContext) {
1260 let http_client = FakeHttpClient::create(move |_request| async move {
1261 Ok(http_client::Response::builder()
1262 .status(401)
1263 .body(http_client::AsyncBody::from(r#"{"error":"invalid_grant"}"#))?)
1264 });
1265
1266 let creds_provider = Arc::new(FakeCredentialsProvider::new());
1267 let state = cx.new(|_cx| State {
1268 credentials: Some(make_expired_credentials()),
1269 sign_in_task: None,
1270 refresh_task: None,
1271 load_task: None,
1272 credentials_provider: creds_provider.clone(),
1273 auth_generation: 0,
1274 last_auth_error: None,
1275 });
1276
1277 let weak_state = cx.read(|_cx| state.downgrade());
1278 let http: Arc<dyn HttpClient> = http_client;
1279
1280 let weak = weak_state.clone();
1281 let http_clone = http.clone();
1282 let result = cx
1283 .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
1284 .await;
1285
1286 cx.run_until_parked();
1287
1288 assert!(result.is_err(), "fatal refresh should return an error");
1289 cx.read(|cx| {
1290 let s = state.read(cx);
1291 assert!(
1292 s.credentials.is_none(),
1293 "credentials should be cleared on fatal refresh failure"
1294 );
1295 assert!(
1296 s.last_auth_error.is_some(),
1297 "last_auth_error should be set on fatal refresh failure"
1298 );
1299 });
1300 }
1301
1302 #[gpui::test]
1303 async fn test_transient_refresh_keeps_credentials(cx: &mut TestAppContext) {
1304 let http_client = FakeHttpClient::create(move |_request| async move {
1305 Ok(http_client::Response::builder()
1306 .status(500)
1307 .body(http_client::AsyncBody::from("Internal Server Error"))?)
1308 });
1309
1310 let state = cx.new(|_cx| State {
1311 credentials: Some(make_expired_credentials()),
1312 sign_in_task: None,
1313 refresh_task: None,
1314 load_task: None,
1315 credentials_provider: Arc::new(FakeCredentialsProvider::new()),
1316 auth_generation: 0,
1317 last_auth_error: None,
1318 });
1319
1320 let weak_state = cx.read(|_cx| state.downgrade());
1321 let http: Arc<dyn HttpClient> = http_client;
1322
1323 let weak = weak_state.clone();
1324 let http_clone = http.clone();
1325 let result = cx
1326 .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
1327 .await;
1328
1329 cx.run_until_parked();
1330
1331 assert!(result.is_err(), "transient refresh should return an error");
1332 cx.read(|cx| {
1333 let s = state.read(cx);
1334 assert!(
1335 s.credentials.is_some(),
1336 "credentials should be kept on transient refresh failure"
1337 );
1338 assert!(
1339 s.last_auth_error.is_none(),
1340 "last_auth_error should not be set on transient refresh failure"
1341 );
1342 });
1343 }
1344
1345 #[gpui::test]
1346 async fn test_sign_out_during_refresh_discards_result(cx: &mut TestAppContext) {
1347 let (gate_tx, gate_rx) = futures::channel::oneshot::channel::<()>();
1348 let gate_rx = Arc::new(Mutex::new(Some(gate_rx)));
1349 let gate_rx_clone = gate_rx.clone();
1350
1351 let http_client = FakeHttpClient::create(move |_request| {
1352 let gate_rx = gate_rx_clone.clone();
1353 async move {
1354 // Wait until the gate is opened, simulating a slow network.
1355 let rx = gate_rx.lock().take();
1356 if let Some(rx) = rx {
1357 let _ = rx.await;
1358 }
1359 let body = fake_token_response();
1360 Ok(http_client::Response::builder()
1361 .status(200)
1362 .body(http_client::AsyncBody::from(body))?)
1363 }
1364 });
1365
1366 let creds_provider = Arc::new(FakeCredentialsProvider::new());
1367 let state = cx.new(|_cx| State {
1368 credentials: Some(make_expired_credentials()),
1369 sign_in_task: None,
1370 refresh_task: None,
1371 load_task: None,
1372 credentials_provider: creds_provider.clone(),
1373 auth_generation: 0,
1374 last_auth_error: None,
1375 });
1376
1377 let weak_state = cx.read(|_cx| state.downgrade());
1378 let http: Arc<dyn HttpClient> = http_client;
1379
1380 // Start a refresh
1381 let weak = weak_state.clone();
1382 let http_clone = http.clone();
1383 let refresh_task =
1384 cx.spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await);
1385
1386 cx.run_until_parked();
1387
1388 // Sign out while the refresh is in-flight
1389 cx.update(|cx| {
1390 do_sign_out(&weak_state, cx).detach();
1391 });
1392 cx.run_until_parked();
1393
1394 // Now let the refresh respond by opening the gate
1395 let _ = gate_tx.send(());
1396 cx.run_until_parked();
1397
1398 let result = refresh_task.await;
1399 assert!(result.is_err(), "refresh should fail after sign-out");
1400
1401 cx.read(|cx| {
1402 let s = state.read(cx);
1403 assert!(
1404 s.credentials.is_none(),
1405 "sign-out should have cleared credentials"
1406 );
1407 });
1408 }
1409
1410 #[gpui::test]
1411 async fn test_sign_out_completes_fully(cx: &mut TestAppContext) {
1412 let creds_provider = Arc::new(FakeCredentialsProvider::new());
1413 // Pre-populate the credential store
1414 creds_provider
1415 .storage
1416 .lock()
1417 .replace(("Bearer".to_string(), b"some-creds".to_vec()));
1418
1419 let state = cx.new(|_cx| State {
1420 credentials: Some(make_fresh_credentials()),
1421 sign_in_task: None,
1422 refresh_task: None,
1423 load_task: None,
1424 credentials_provider: creds_provider.clone(),
1425 auth_generation: 0,
1426 last_auth_error: None,
1427 });
1428
1429 let weak_state = cx.read(|_cx| state.downgrade());
1430 let sign_out_task = cx.update(|cx| do_sign_out(&weak_state, cx));
1431
1432 cx.run_until_parked();
1433 sign_out_task.await.expect("sign-out should succeed");
1434
1435 assert!(
1436 creds_provider.storage.lock().is_none(),
1437 "credential store should be empty after sign-out"
1438 );
1439 cx.read(|cx| {
1440 assert!(
1441 !state.read(cx).is_authenticated(),
1442 "state should show not authenticated"
1443 );
1444 });
1445 }
1446
1447 #[gpui::test]
1448 async fn test_authenticate_awaits_initial_load(cx: &mut TestAppContext) {
1449 let creds = make_fresh_credentials();
1450 let creds_json = serde_json::to_vec(&creds).unwrap();
1451 let creds_provider = Arc::new(FakeCredentialsProvider::new());
1452 creds_provider
1453 .storage
1454 .lock()
1455 .replace(("Bearer".to_string(), creds_json));
1456
1457 let http_client = FakeHttpClient::create(|_| async {
1458 Ok(http_client::Response::builder()
1459 .status(200)
1460 .body(http_client::AsyncBody::default())?)
1461 });
1462
1463 let provider =
1464 cx.update(|cx| OpenAiSubscribedProvider::new(http_client, creds_provider, cx));
1465
1466 // Before load completes, authenticate should still await the load.
1467 let auth_task = cx.update(|cx| provider.authenticate(cx));
1468
1469 // Drive the load to completion.
1470 cx.run_until_parked();
1471
1472 let result = auth_task.await;
1473 assert!(
1474 result.is_ok(),
1475 "authenticate should succeed after load completes with valid credentials"
1476 );
1477 }
1478}