diff --git a/crates/language_models/src/provider/openai_subscribed.rs b/crates/language_models/src/provider/openai_subscribed.rs index 79bb08e9e0ee16baef18956859cb06827ca21e70..5ce6b44b2bcb9c98ddf59c41242fb40d2b7b4f69 100644 --- a/crates/language_models/src/provider/openai_subscribed.rs +++ b/crates/language_models/src/provider/openai_subscribed.rs @@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow}; use base64::Engine as _; use base64::engine::general_purpose::URL_SAFE_NO_PAD; use credentials_provider::CredentialsProvider; -use futures::{FutureExt, StreamExt, future::BoxFuture}; +use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared}; use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use language_model::{ @@ -54,6 +54,7 @@ impl CodexCredentials { pub struct State { credentials: Option, sign_in_task: Option>>, + refresh_task: Option>>>>, credentials_provider: Arc, } @@ -85,6 +86,7 @@ impl OpenAiSubscribedProvider { let state = cx.new(|_cx| State { credentials: None, sign_in_task: None, + refresh_task: None, credentials_provider, }); @@ -536,40 +538,89 @@ async fn get_fresh_credentials( http_client: &Arc, cx: &mut AsyncApp, ) -> Result { - let creds = state - .read_with(&*cx, |s, _| s.credentials.clone()) - .map_err(LanguageModelCompletionError::Other)? - .ok_or(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - })?; + let (creds, existing_task) = state + .read_with(&*cx, |s, _| (s.credentials.clone(), s.refresh_task.clone())) + .map_err(LanguageModelCompletionError::Other)?; + + let creds = creds.ok_or(LanguageModelCompletionError::NoApiKey { + provider: PROVIDER_NAME, + })?; if !creds.is_expired() { return Ok(creds); } - let refreshed = refresh_token(http_client, &creds.refresh_token) - .await - .map_err(LanguageModelCompletionError::Other)?; + // If another caller is already refreshing, await their result. + if let Some(shared_task) = existing_task { + return shared_task + .await + .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}"))); + } - let credentials_provider = state - .read_with(&*cx, |s, _| s.credentials_provider.clone()) - .map_err(LanguageModelCompletionError::Other)?; + // We are the first caller to notice expiry — spawn the refresh task. + let http_client_clone = http_client.clone(); + let state_clone = state.clone(); + let refresh_token_value = creds.refresh_token.clone(); - let json = serde_json::to_vec(&refreshed) - .map_err(|e| LanguageModelCompletionError::Other(e.into()))?; + let shared_task = cx + .spawn(async move |cx| { + let result = refresh_token(&http_client_clone, &refresh_token_value).await; - credentials_provider - .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx) - .await - .map_err(LanguageModelCompletionError::Other)?; + match result { + Ok(refreshed) => { + let persist_result: Result> = async { + let credentials_provider = state_clone + .read_with(&*cx, |s, _| s.credentials_provider.clone()) + .map_err(|e| Arc::new(e))?; + + let json = + serde_json::to_vec(&refreshed).map_err(|e| Arc::new(e.into()))?; + + credentials_provider + .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx) + .await + .map_err(|e| Arc::new(e))?; + state_clone + .update(cx, |s, _| { + s.credentials = Some(refreshed.clone()); + s.refresh_task = None; + }) + .map_err(|e| Arc::new(e))?; + + Ok(refreshed) + } + .await; + + // Clear refresh_task on failure too. + if persist_result.is_err() { + let _ = state_clone.update(cx, |s, _| { + s.refresh_task = None; + }); + } + + persist_result + } + Err(e) => { + let _ = state_clone.update(cx, |s, _| { + s.refresh_task = None; + }); + Err(Arc::new(e)) + } + } + }) + .shared(); + + // Store the shared task so concurrent callers can join on it. state .update(cx, |s, _| { - s.credentials = Some(refreshed.clone()); + s.refresh_task = Some(shared_task.clone()); }) .map_err(LanguageModelCompletionError::Other)?; - Ok(refreshed) + shared_task + .await + .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}"))) } // --- OAuth PKCE flow ---