diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index e349cf100456087c494d4e4426998f62c6854a1d..4b141205053efc2dbf2fce81087ec9ed8dc25e75 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -71,6 +71,9 @@ vercel = { workspace = true, features = ["schemars"] } x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] +gpui = { workspace = true, features = ["test-support"] } +http_client = { workspace = true, features = ["test-support"] } language_model = { workspace = true, features = ["test-support"] } +parking_lot.workspace = true pretty_assertions.workspace = true diff --git a/crates/language_models/src/provider/openai_subscribed.rs b/crates/language_models/src/provider/openai_subscribed.rs index ea503c8a3f1f6b4e0aa86ccb94756efc1ad5ad83..53096696d6b9d6268dced807d4ebad0c7ecf784a 100644 --- a/crates/language_models/src/provider/openai_subscribed.rs +++ b/crates/language_models/src/provider/openai_subscribed.rs @@ -981,3 +981,211 @@ impl Render for ConfigurationView { .into_any_element() } } + +#[cfg(test)] +mod tests { + use super::*; + use gpui::TestAppContext; + use http_client::FakeHttpClient; + use parking_lot::Mutex; + use std::future::Future; + use std::pin::Pin; + use std::sync::atomic::{AtomicUsize, Ordering}; + + struct FakeCredentialsProvider { + storage: Mutex)>>, + } + + impl FakeCredentialsProvider { + fn new() -> Self { + Self { + storage: Mutex::new(None), + } + } + } + + impl CredentialsProvider for FakeCredentialsProvider { + fn read_credentials<'a>( + &'a self, + _url: &'a str, + _cx: &'a AsyncApp, + ) -> Pin)>>> + 'a>> { + Box::pin(async { Ok(self.storage.lock().clone()) }) + } + + fn write_credentials<'a>( + &'a self, + _url: &'a str, + username: &'a str, + password: &'a [u8], + _cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + self.storage + .lock() + .replace((username.to_string(), password.to_vec())); + Box::pin(async { Ok(()) }) + } + + fn delete_credentials<'a>( + &'a self, + _url: &'a str, + _cx: &'a AsyncApp, + ) -> Pin> + 'a>> { + *self.storage.lock() = None; + Box::pin(async { Ok(()) }) + } + } + + fn make_expired_credentials() -> CodexCredentials { + CodexCredentials { + access_token: "old_access".to_string(), + refresh_token: "old_refresh".to_string(), + expires_at_ms: 0, + account_id: None, + email: None, + } + } + + fn make_fresh_credentials() -> CodexCredentials { + CodexCredentials { + access_token: "fresh_access".to_string(), + refresh_token: "fresh_refresh".to_string(), + expires_at_ms: now_ms() + 3_600_000, + account_id: None, + email: None, + } + } + + fn fake_token_response() -> String { + serde_json::json!({ + "access_token": "fresh_access", + "refresh_token": "fresh_refresh", + "expires_in": 3600 + }) + .to_string() + } + + #[gpui::test] + async fn test_concurrent_refresh_deduplicates(cx: &mut TestAppContext) { + let refresh_count = Arc::new(AtomicUsize::new(0)); + let refresh_count_clone = refresh_count.clone(); + + let http_client = FakeHttpClient::create(move |_request| { + let refresh_count = refresh_count_clone.clone(); + async move { + refresh_count.fetch_add(1, Ordering::SeqCst); + let body = fake_token_response(); + Ok(http_client::Response::builder() + .status(200) + .body(http_client::AsyncBody::from(body))?) + } + }); + + let state = cx.new(|_cx| State { + credentials: Some(make_expired_credentials()), + sign_in_task: None, + refresh_task: None, + credentials_provider: Arc::new(FakeCredentialsProvider::new()), + }); + + let weak_state = cx.read(|_cx| state.downgrade()); + let http: Arc = http_client; + + // Spawn two concurrent refresh attempts. + let weak1 = weak_state.clone(); + let http1 = http.clone(); + let task1 = + cx.spawn(async move |mut cx| get_fresh_credentials(&weak1, &http1, &mut cx).await); + + let weak2 = weak_state.clone(); + let http2 = http.clone(); + let task2 = + cx.spawn(async move |mut cx| get_fresh_credentials(&weak2, &http2, &mut cx).await); + + // Drive both to completion. + cx.run_until_parked(); + let result1 = task1.await; + let result2 = task2.await; + + assert!(result1.is_ok(), "first refresh should succeed"); + assert!(result2.is_ok(), "second refresh should succeed"); + assert_eq!(result1.unwrap().access_token, "fresh_access"); + assert_eq!(result2.unwrap().access_token, "fresh_access"); + assert_eq!( + refresh_count.load(Ordering::SeqCst), + 1, + "refresh_token should only be called once despite two concurrent callers" + ); + } + + #[gpui::test] + async fn test_fresh_credentials_skip_refresh(cx: &mut TestAppContext) { + let refresh_count = Arc::new(AtomicUsize::new(0)); + let refresh_count_clone = refresh_count.clone(); + + let http_client = FakeHttpClient::create(move |_request| { + let refresh_count = refresh_count_clone.clone(); + async move { + refresh_count.fetch_add(1, Ordering::SeqCst); + let body = fake_token_response(); + Ok(http_client::Response::builder() + .status(200) + .body(http_client::AsyncBody::from(body))?) + } + }); + + let state = cx.new(|_cx| State { + credentials: Some(make_fresh_credentials()), + sign_in_task: None, + refresh_task: None, + credentials_provider: Arc::new(FakeCredentialsProvider::new()), + }); + + let weak_state = cx.read(|_cx| state.downgrade()); + let http: Arc = http_client; + + let weak = weak_state.clone(); + let http_clone = http.clone(); + let result = cx + .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().access_token, "fresh_access"); + assert_eq!( + refresh_count.load(Ordering::SeqCst), + 0, + "no refresh should happen when credentials are fresh" + ); + } + + #[gpui::test] + async fn test_no_credentials_returns_no_api_key(cx: &mut TestAppContext) { + let http_client = FakeHttpClient::create(|_| async { + Ok(http_client::Response::builder() + .status(200) + .body(http_client::AsyncBody::default())?) + }); + + let state = cx.new(|_cx| State { + credentials: None, + sign_in_task: None, + refresh_task: None, + credentials_provider: Arc::new(FakeCredentialsProvider::new()), + }); + + let weak_state = cx.read(|_cx| state.downgrade()); + let http: Arc = http_client; + + let weak = weak_state.clone(); + let http_clone = http.clone(); + let result = cx + .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await) + .await; + + assert!(matches!( + result, + Err(LanguageModelCompletionError::NoApiKey { .. }) + )); + } +}