@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, StreamExt, future::BoxFuture};
+use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared};
use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use language_model::{
@@ -54,6 +54,7 @@ impl CodexCredentials {
pub struct State {
credentials: Option<CodexCredentials>,
sign_in_task: Option<Task<Result<()>>>,
+ refresh_task: Option<Shared<Task<Result<CodexCredentials, Arc<anyhow::Error>>>>>,
credentials_provider: Arc<dyn CredentialsProvider>,
}
@@ -85,6 +86,7 @@ impl OpenAiSubscribedProvider {
let state = cx.new(|_cx| State {
credentials: None,
sign_in_task: None,
+ refresh_task: None,
credentials_provider,
});
@@ -536,40 +538,89 @@ async fn get_fresh_credentials(
http_client: &Arc<dyn HttpClient>,
cx: &mut AsyncApp,
) -> Result<CodexCredentials, LanguageModelCompletionError> {
- let creds = state
- .read_with(&*cx, |s, _| s.credentials.clone())
- .map_err(LanguageModelCompletionError::Other)?
- .ok_or(LanguageModelCompletionError::NoApiKey {
- provider: PROVIDER_NAME,
- })?;
+ let (creds, existing_task) = state
+ .read_with(&*cx, |s, _| (s.credentials.clone(), s.refresh_task.clone()))
+ .map_err(LanguageModelCompletionError::Other)?;
+
+ let creds = creds.ok_or(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ })?;
if !creds.is_expired() {
return Ok(creds);
}
- let refreshed = refresh_token(http_client, &creds.refresh_token)
- .await
- .map_err(LanguageModelCompletionError::Other)?;
+ // If another caller is already refreshing, await their result.
+ if let Some(shared_task) = existing_task {
+ return shared_task
+ .await
+ .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")));
+ }
- let credentials_provider = state
- .read_with(&*cx, |s, _| s.credentials_provider.clone())
- .map_err(LanguageModelCompletionError::Other)?;
+ // We are the first caller to notice expiry — spawn the refresh task.
+ let http_client_clone = http_client.clone();
+ let state_clone = state.clone();
+ let refresh_token_value = creds.refresh_token.clone();
- let json = serde_json::to_vec(&refreshed)
- .map_err(|e| LanguageModelCompletionError::Other(e.into()))?;
+ let shared_task = cx
+ .spawn(async move |cx| {
+ let result = refresh_token(&http_client_clone, &refresh_token_value).await;
- credentials_provider
- .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
- .await
- .map_err(LanguageModelCompletionError::Other)?;
+ match result {
+ Ok(refreshed) => {
+ let persist_result: Result<CodexCredentials, Arc<anyhow::Error>> = async {
+ let credentials_provider = state_clone
+ .read_with(&*cx, |s, _| s.credentials_provider.clone())
+ .map_err(|e| Arc::new(e))?;
+
+ let json =
+ serde_json::to_vec(&refreshed).map_err(|e| Arc::new(e.into()))?;
+
+ credentials_provider
+ .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
+ .await
+ .map_err(|e| Arc::new(e))?;
+ state_clone
+ .update(cx, |s, _| {
+ s.credentials = Some(refreshed.clone());
+ s.refresh_task = None;
+ })
+ .map_err(|e| Arc::new(e))?;
+
+ Ok(refreshed)
+ }
+ .await;
+
+ // Clear refresh_task on failure too.
+ if persist_result.is_err() {
+ let _ = state_clone.update(cx, |s, _| {
+ s.refresh_task = None;
+ });
+ }
+
+ persist_result
+ }
+ Err(e) => {
+ let _ = state_clone.update(cx, |s, _| {
+ s.refresh_task = None;
+ });
+ Err(Arc::new(e))
+ }
+ }
+ })
+ .shared();
+
+ // Store the shared task so concurrent callers can join on it.
state
.update(cx, |s, _| {
- s.credentials = Some(refreshed.clone());
+ s.refresh_task = Some(shared_task.clone());
})
.map_err(LanguageModelCompletionError::Other)?;
- Ok(refreshed)
+ shared_task
+ .await
+ .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")))
}
// --- OAuth PKCE flow ---