Fix auth state reliability issues in ChatGPT subscription provider

morgankrey created

- Propagate refreshed credentials to in-memory State immediately after
  persisting, so subsequent requests don't redundantly refresh again
- Clear sign_in_task when keychain write fails during OAuth, so the
  UI doesn't get stuck in a permanent "Signing in..." state

Change summary

crates/language_models/src/provider/openai_subscribed.rs | 60 +++++++---
1 file changed, 42 insertions(+), 18 deletions(-)

Detailed changes

crates/language_models/src/provider/openai_subscribed.rs 🔗

@@ -128,17 +128,37 @@ impl OpenAiSubscribedProvider {
         let task = cx.spawn(async move |cx| {
             match do_oauth_flow(http_client, &*cx).await {
                 Ok(creds) => {
-                    let credentials_provider =
-                        state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
-                    let json = serde_json::to_vec(&creds)?;
-                    credentials_provider
-                        .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
-                        .await?;
-                    state.update(cx, |s, cx| {
-                        s.credentials = Some(creds);
-                        s.sign_in_task = None;
-                        cx.notify();
-                    })?;
+                    let persist_result = async {
+                        let credentials_provider =
+                            state.read_with(&*cx, |s, _| s.credentials_provider.clone())?;
+                        let json = serde_json::to_vec(&creds)?;
+                        credentials_provider
+                            .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
+                            .await?;
+                        anyhow::Ok(())
+                    }
+                    .await;
+
+                    match persist_result {
+                        Ok(()) => {
+                            state
+                                .update(cx, |s, cx| {
+                                    s.credentials = Some(creds);
+                                    s.sign_in_task = None;
+                                    cx.notify();
+                                })
+                                .log_err();
+                        }
+                        Err(err) => {
+                            log::error!("ChatGPT subscription sign-in failed to persist credentials: {err:?}");
+                            state
+                                .update(cx, |s, cx| {
+                                    s.sign_in_task = None;
+                                    cx.notify();
+                                })
+                                .log_err();
+                        }
+                    }
                 }
                 Err(err) => {
                     log::error!("ChatGPT subscription sign-in failed: {err:?}");
@@ -392,7 +412,7 @@ impl LanguageModel for OpenAiSubscribedLanguageModel {
         let request_limiter = self.request_limiter.clone();
 
         let future = cx.spawn(async move |cx| {
-            let creds = get_fresh_credentials(&state, &http_client, &*cx).await?;
+            let creds = get_fresh_credentials(&state, &http_client, cx).await?;
 
             let mut extra_headers: Vec<(String, String)> = vec![
                 ("originator".into(), "zed".into()),
@@ -434,10 +454,10 @@ impl LanguageModel for OpenAiSubscribedLanguageModel {
 async fn get_fresh_credentials(
     state: &gpui::WeakEntity<State>,
     http_client: &Arc<dyn HttpClient>,
-    cx: &AsyncApp,
+    cx: &mut AsyncApp,
 ) -> Result<CodexCredentials, LanguageModelCompletionError> {
     let creds = state
-        .read_with(cx, |s, _| s.credentials.clone())
+        .read_with(&*cx, |s, _| s.credentials.clone())
         .map_err(|e| LanguageModelCompletionError::Other(e.into()))?
         .ok_or(LanguageModelCompletionError::NoApiKey {
             provider: PROVIDER_NAME,
@@ -452,19 +472,23 @@ async fn get_fresh_credentials(
         .map_err(LanguageModelCompletionError::Other)?;
 
     let credentials_provider = state
-        .read_with(cx, |s, _| s.credentials_provider.clone())
+        .read_with(&*cx, |s, _| s.credentials_provider.clone())
         .map_err(|e| LanguageModelCompletionError::Other(e.into()))?;
 
     let json = serde_json::to_vec(&refreshed)
         .map_err(|e| LanguageModelCompletionError::Other(e.into()))?;
 
     credentials_provider
-        .write_credentials(CREDENTIALS_KEY, "Bearer", &json, cx)
+        .write_credentials(CREDENTIALS_KEY, "Bearer", &json, &*cx)
         .await
         .map_err(LanguageModelCompletionError::Other)?;
 
-    // The entity state will get the updated credentials on next login/load;
-    // for this request we use the freshly-fetched token.
+    state
+        .update(cx, |s, _| {
+            s.credentials = Some(refreshed.clone());
+        })
+        .map_err(|e| LanguageModelCompletionError::Other(e.into()))?;
+
     Ok(refreshed)
 }