Deduplicate concurrent credential refresh with shared task

Richard Feldman created

Previously, when several concurrent stream_completion calls saw expired
credentials, each one independently called refresh_token(). Because OAuth
rotates refresh tokens, the first refresh invalidates the old token, so
every other concurrent refresh fails.

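Add a refresh_task field (a Shared<Task>) to State: the first caller to
notice expiry spawns the refresh task, and subsequent callers await the
same shared future. The task resets refresh_task to None when it
finishes, on success and on failure alike.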

Change summary

crates/language_models/src/provider/openai_subscribed.rs | 93 +++++++--
1 file changed, 72 insertions(+), 21 deletions(-)

Detailed changes

crates/language_models/src/provider/openai_subscribed.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result, anyhow};
 use base64::Engine as _;
 use base64::engine::general_purpose::URL_SAFE_NO_PAD;
 use credentials_provider::CredentialsProvider;
-use futures::{FutureExt, StreamExt, future::BoxFuture};
+use futures::{FutureExt, StreamExt, future::BoxFuture, future::Shared};
 use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 use language_model::{
@@ -54,6 +54,7 @@ impl CodexCredentials {
 pub struct State {
     credentials: Option<CodexCredentials>,
     sign_in_task: Option<Task<Result<()>>>,
+    refresh_task: Option<Shared<Task<Result<CodexCredentials, Arc<anyhow::Error>>>>>,
     credentials_provider: Arc<dyn CredentialsProvider>,
 }
 
@@ -85,6 +86,7 @@ impl OpenAiSubscribedProvider {
         let state = cx.new(|_cx| State {
             credentials: None,
             sign_in_task: None,
+            refresh_task: None,
             credentials_provider,
         });
 
@@ -536,40 +538,89 @@ async fn get_fresh_credentials(
     http_client: &Arc<dyn HttpClient>,
     cx: &mut AsyncApp,
 ) -> Result<CodexCredentials, LanguageModelCompletionError> {
-    let creds = state
-        .read_with(&*cx, |s, _| s.credentials.clone())
-        .map_err(LanguageModelCompletionError::Other)?
-        .ok_or(LanguageModelCompletionError::NoApiKey {
-            provider: PROVIDER_NAME,
-        })?;
+    let (creds, existing_task) = state
+        .read_with(&*cx, |s, _| (s.credentials.clone(), s.refresh_task.clone()))
+        .map_err(LanguageModelCompletionError::Other)?;
+
+    let creds = creds.ok_or(LanguageModelCompletionError::NoApiKey {
+        provider: PROVIDER_NAME,
+    })?;
 
     if !creds.is_expired() {
         return Ok(creds);
     }
 
-    let refreshed = refresh_token(http_client, &creds.refresh_token)
-        .await
-        .map_err(LanguageModelCompletionError::Other)?;
+    // If another caller is already refreshing, await their result.
+    if let Some(shared_task) = existing_task {
+        return shared_task
+            .await
+            .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")));
+    }
 
-    let credentials_provider = state
-        .read_with(&*cx, |s, _| s.credentials_provider.clone())
-        .map_err(LanguageModelCompletionError::Other)?;
+    // We are the first caller to notice expiry — spawn the refresh task.
+    let http_client_clone = http_client.clone();
+    let state_clone = state.clone();
+    let refresh_token_value = creds.refresh_token.clone();
 
-    let json = serde_json::to_vec(&refreshed)
-        .map_err(|e| LanguageModelCompletionError::Other(e.into()))?;
+    let shared_task = cx
+        .spawn(async move |cx| {
+            let result = refresh_token(&http_client_clone, &refresh_token_value).await;
 
-    credentials_provider
-        .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
-        .await
-        .map_err(LanguageModelCompletionError::Other)?;
+            match result {
+                Ok(refreshed) => {
+                    let persist_result: Result<CodexCredentials, Arc<anyhow::Error>> = async {
+                        let credentials_provider = state_clone
+                            .read_with(&*cx, |s, _| s.credentials_provider.clone())
+                            .map_err(|e| Arc::new(e))?;
+
+                        let json =
+                            serde_json::to_vec(&refreshed).map_err(|e| Arc::new(e.into()))?;
+
+                        credentials_provider
+                            .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
+                            .await
+                            .map_err(|e| Arc::new(e))?;
 
+                        state_clone
+                            .update(cx, |s, _| {
+                                s.credentials = Some(refreshed.clone());
+                                s.refresh_task = None;
+                            })
+                            .map_err(|e| Arc::new(e))?;
+
+                        Ok(refreshed)
+                    }
+                    .await;
+
+                    // Clear refresh_task on failure too.
+                    if persist_result.is_err() {
+                        let _ = state_clone.update(cx, |s, _| {
+                            s.refresh_task = None;
+                        });
+                    }
+
+                    persist_result
+                }
+                Err(e) => {
+                    let _ = state_clone.update(cx, |s, _| {
+                        s.refresh_task = None;
+                    });
+                    Err(Arc::new(e))
+                }
+            }
+        })
+        .shared();
+
+    // Store the shared task so concurrent callers can join on it.
     state
         .update(cx, |s, _| {
-            s.credentials = Some(refreshed.clone());
+            s.refresh_task = Some(shared_task.clone());
         })
         .map_err(LanguageModelCompletionError::Other)?;
 
-    Ok(refreshed)
+    shared_task
+        .await
+        .map_err(|e| LanguageModelCompletionError::Other(anyhow::anyhow!("{e}")))
 }
 
 // --- OAuth PKCE flow ---