Add tests for credential refresh deduplication

Richard Feldman created

Three tests covering get_fresh_credentials:
- test_concurrent_refresh_deduplicates: two concurrent callers with
  expired credentials only trigger one HTTP refresh call
- test_fresh_credentials_skip_refresh: fresh credentials return
  immediately with no HTTP call
- test_no_credentials_returns_no_api_key: missing credentials return
  the correct error variant

Change summary

crates/language_models/Cargo.toml                        |   3 
crates/language_models/src/provider/openai_subscribed.rs | 208 ++++++++++
2 files changed, 211 insertions(+)

Detailed changes

crates/language_models/Cargo.toml 🔗

@@ -71,6 +71,9 @@ vercel = { workspace = true, features = ["schemars"] }
 x_ai = { workspace = true, features = ["schemars"] }
 
 [dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
+http_client = { workspace = true, features = ["test-support"] }
 language_model = { workspace = true, features = ["test-support"] }
+parking_lot.workspace = true
 pretty_assertions.workspace = true
 

crates/language_models/src/provider/openai_subscribed.rs 🔗

@@ -981,3 +981,211 @@ impl Render for ConfigurationView {
             .into_any_element()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::TestAppContext;
+    use http_client::FakeHttpClient;
+    use parking_lot::Mutex;
+    use std::future::Future;
+    use std::pin::Pin;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+
+    struct FakeCredentialsProvider {
+        storage: Mutex<Option<(String, Vec<u8>)>>,
+    }
+
+    impl FakeCredentialsProvider {
+        fn new() -> Self {
+            Self {
+                storage: Mutex::new(None),
+            }
+        }
+    }
+
+    impl CredentialsProvider for FakeCredentialsProvider {
+        fn read_credentials<'a>(
+            &'a self,
+            _url: &'a str,
+            _cx: &'a AsyncApp,
+        ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
+            Box::pin(async { Ok(self.storage.lock().clone()) })
+        }
+
+        fn write_credentials<'a>(
+            &'a self,
+            _url: &'a str,
+            username: &'a str,
+            password: &'a [u8],
+            _cx: &'a AsyncApp,
+        ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+            self.storage
+                .lock()
+                .replace((username.to_string(), password.to_vec()));
+            Box::pin(async { Ok(()) })
+        }
+
+        fn delete_credentials<'a>(
+            &'a self,
+            _url: &'a str,
+            _cx: &'a AsyncApp,
+        ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+            *self.storage.lock() = None;
+            Box::pin(async { Ok(()) })
+        }
+    }
+
+    fn make_expired_credentials() -> CodexCredentials {
+        CodexCredentials {
+            access_token: "old_access".to_string(),
+            refresh_token: "old_refresh".to_string(),
+            expires_at_ms: 0,
+            account_id: None,
+            email: None,
+        }
+    }
+
+    fn make_fresh_credentials() -> CodexCredentials {
+        CodexCredentials {
+            access_token: "fresh_access".to_string(),
+            refresh_token: "fresh_refresh".to_string(),
+            expires_at_ms: now_ms() + 3_600_000,
+            account_id: None,
+            email: None,
+        }
+    }
+
+    fn fake_token_response() -> String {
+        serde_json::json!({
+            "access_token": "fresh_access",
+            "refresh_token": "fresh_refresh",
+            "expires_in": 3600
+        })
+        .to_string()
+    }
+
+    #[gpui::test]
+    async fn test_concurrent_refresh_deduplicates(cx: &mut TestAppContext) {
+        let refresh_count = Arc::new(AtomicUsize::new(0));
+        let refresh_count_clone = refresh_count.clone();
+
+        let http_client = FakeHttpClient::create(move |_request| {
+            let refresh_count = refresh_count_clone.clone();
+            async move {
+                refresh_count.fetch_add(1, Ordering::SeqCst);
+                let body = fake_token_response();
+                Ok(http_client::Response::builder()
+                    .status(200)
+                    .body(http_client::AsyncBody::from(body))?)
+            }
+        });
+
+        let state = cx.new(|_cx| State {
+            credentials: Some(make_expired_credentials()),
+            sign_in_task: None,
+            refresh_task: None,
+            credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+        });
+
+        let weak_state = cx.read(|_cx| state.downgrade());
+        let http: Arc<dyn HttpClient> = http_client;
+
+        // Spawn two concurrent refresh attempts.
+        let weak1 = weak_state.clone();
+        let http1 = http.clone();
+        let task1 =
+            cx.spawn(async move |mut cx| get_fresh_credentials(&weak1, &http1, &mut cx).await);
+
+        let weak2 = weak_state.clone();
+        let http2 = http.clone();
+        let task2 =
+            cx.spawn(async move |mut cx| get_fresh_credentials(&weak2, &http2, &mut cx).await);
+
+        // Drive both to completion.
+        cx.run_until_parked();
+        let result1 = task1.await;
+        let result2 = task2.await;
+
+        assert!(result1.is_ok(), "first refresh should succeed");
+        assert!(result2.is_ok(), "second refresh should succeed");
+        assert_eq!(result1.unwrap().access_token, "fresh_access");
+        assert_eq!(result2.unwrap().access_token, "fresh_access");
+        assert_eq!(
+            refresh_count.load(Ordering::SeqCst),
+            1,
+            "refresh_token should only be called once despite two concurrent callers"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_fresh_credentials_skip_refresh(cx: &mut TestAppContext) {
+        let refresh_count = Arc::new(AtomicUsize::new(0));
+        let refresh_count_clone = refresh_count.clone();
+
+        let http_client = FakeHttpClient::create(move |_request| {
+            let refresh_count = refresh_count_clone.clone();
+            async move {
+                refresh_count.fetch_add(1, Ordering::SeqCst);
+                let body = fake_token_response();
+                Ok(http_client::Response::builder()
+                    .status(200)
+                    .body(http_client::AsyncBody::from(body))?)
+            }
+        });
+
+        let state = cx.new(|_cx| State {
+            credentials: Some(make_fresh_credentials()),
+            sign_in_task: None,
+            refresh_task: None,
+            credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+        });
+
+        let weak_state = cx.read(|_cx| state.downgrade());
+        let http: Arc<dyn HttpClient> = http_client;
+
+        let weak = weak_state.clone();
+        let http_clone = http.clone();
+        let result = cx
+            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+            .await;
+
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().access_token, "fresh_access");
+        assert_eq!(
+            refresh_count.load(Ordering::SeqCst),
+            0,
+            "no refresh should happen when credentials are fresh"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_no_credentials_returns_no_api_key(cx: &mut TestAppContext) {
+        let http_client = FakeHttpClient::create(|_| async {
+            Ok(http_client::Response::builder()
+                .status(200)
+                .body(http_client::AsyncBody::default())?)
+        });
+
+        let state = cx.new(|_cx| State {
+            credentials: None,
+            sign_in_task: None,
+            refresh_task: None,
+            credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+        });
+
+        let weak_state = cx.read(|_cx| state.downgrade());
+        let http: Arc<dyn HttpClient> = http_client;
+
+        let weak = weak_state.clone();
+        let http_clone = http.clone();
+        let result = cx
+            .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+            .await;
+
+        assert!(matches!(
+            result,
+            Err(LanguageModelCompletionError::NoApiKey { .. })
+        ));
+    }
+}