@@ -71,6 +71,9 @@ vercel = { workspace = true, features = ["schemars"] }
x_ai = { workspace = true, features = ["schemars"] }
[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
+http_client = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
+parking_lot.workspace = true
pretty_assertions.workspace = true
@@ -981,3 +981,211 @@ impl Render for ConfigurationView {
.into_any_element()
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::TestAppContext;
+ use http_client::FakeHttpClient;
+ use parking_lot::Mutex;
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+
+ struct FakeCredentialsProvider {
+ storage: Mutex<Option<(String, Vec<u8>)>>,
+ }
+
+ impl FakeCredentialsProvider {
+ fn new() -> Self {
+ Self {
+ storage: Mutex::new(None),
+ }
+ }
+ }
+
+ impl CredentialsProvider for FakeCredentialsProvider {
+ fn read_credentials<'a>(
+ &'a self,
+ _url: &'a str,
+ _cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<Option<(String, Vec<u8>)>>> + 'a>> {
+ Box::pin(async { Ok(self.storage.lock().clone()) })
+ }
+
+ fn write_credentials<'a>(
+ &'a self,
+ _url: &'a str,
+ username: &'a str,
+ password: &'a [u8],
+ _cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ self.storage
+ .lock()
+ .replace((username.to_string(), password.to_vec()));
+ Box::pin(async { Ok(()) })
+ }
+
+ fn delete_credentials<'a>(
+ &'a self,
+ _url: &'a str,
+ _cx: &'a AsyncApp,
+ ) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
+ *self.storage.lock() = None;
+ Box::pin(async { Ok(()) })
+ }
+ }
+
+ fn make_expired_credentials() -> CodexCredentials {
+ CodexCredentials {
+ access_token: "old_access".to_string(),
+ refresh_token: "old_refresh".to_string(),
+ expires_at_ms: 0,
+ account_id: None,
+ email: None,
+ }
+ }
+
+ fn make_fresh_credentials() -> CodexCredentials {
+ CodexCredentials {
+ access_token: "fresh_access".to_string(),
+ refresh_token: "fresh_refresh".to_string(),
+ expires_at_ms: now_ms() + 3_600_000,
+ account_id: None,
+ email: None,
+ }
+ }
+
+ fn fake_token_response() -> String {
+ serde_json::json!({
+ "access_token": "fresh_access",
+ "refresh_token": "fresh_refresh",
+ "expires_in": 3600
+ })
+ .to_string()
+ }
+
+ #[gpui::test]
+ async fn test_concurrent_refresh_deduplicates(cx: &mut TestAppContext) {
+ let refresh_count = Arc::new(AtomicUsize::new(0));
+ let refresh_count_clone = refresh_count.clone();
+
+ let http_client = FakeHttpClient::create(move |_request| {
+ let refresh_count = refresh_count_clone.clone();
+ async move {
+ refresh_count.fetch_add(1, Ordering::SeqCst);
+ let body = fake_token_response();
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(http_client::AsyncBody::from(body))?)
+ }
+ });
+
+ let state = cx.new(|_cx| State {
+ credentials: Some(make_expired_credentials()),
+ sign_in_task: None,
+ refresh_task: None,
+ credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+ });
+
+ let weak_state = cx.read(|_cx| state.downgrade());
+ let http: Arc<dyn HttpClient> = http_client;
+
+ // Spawn two concurrent refresh attempts.
+ let weak1 = weak_state.clone();
+ let http1 = http.clone();
+ let task1 =
+ cx.spawn(async move |mut cx| get_fresh_credentials(&weak1, &http1, &mut cx).await);
+
+ let weak2 = weak_state.clone();
+ let http2 = http.clone();
+ let task2 =
+ cx.spawn(async move |mut cx| get_fresh_credentials(&weak2, &http2, &mut cx).await);
+
+ // Drive both to completion.
+ cx.run_until_parked();
+ let result1 = task1.await;
+ let result2 = task2.await;
+
+ assert!(result1.is_ok(), "first refresh should succeed");
+ assert!(result2.is_ok(), "second refresh should succeed");
+ assert_eq!(result1.unwrap().access_token, "fresh_access");
+ assert_eq!(result2.unwrap().access_token, "fresh_access");
+ assert_eq!(
+ refresh_count.load(Ordering::SeqCst),
+ 1,
+ "refresh_token should only be called once despite two concurrent callers"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_fresh_credentials_skip_refresh(cx: &mut TestAppContext) {
+ let refresh_count = Arc::new(AtomicUsize::new(0));
+ let refresh_count_clone = refresh_count.clone();
+
+ let http_client = FakeHttpClient::create(move |_request| {
+ let refresh_count = refresh_count_clone.clone();
+ async move {
+ refresh_count.fetch_add(1, Ordering::SeqCst);
+ let body = fake_token_response();
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(http_client::AsyncBody::from(body))?)
+ }
+ });
+
+ let state = cx.new(|_cx| State {
+ credentials: Some(make_fresh_credentials()),
+ sign_in_task: None,
+ refresh_task: None,
+ credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+ });
+
+ let weak_state = cx.read(|_cx| state.downgrade());
+ let http: Arc<dyn HttpClient> = http_client;
+
+ let weak = weak_state.clone();
+ let http_clone = http.clone();
+ let result = cx
+ .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+ .await;
+
+ assert!(result.is_ok());
+ assert_eq!(result.unwrap().access_token, "fresh_access");
+ assert_eq!(
+ refresh_count.load(Ordering::SeqCst),
+ 0,
+ "no refresh should happen when credentials are fresh"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_no_credentials_returns_no_api_key(cx: &mut TestAppContext) {
+ let http_client = FakeHttpClient::create(|_| async {
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(http_client::AsyncBody::default())?)
+ });
+
+ let state = cx.new(|_cx| State {
+ credentials: None,
+ sign_in_task: None,
+ refresh_task: None,
+ credentials_provider: Arc::new(FakeCredentialsProvider::new()),
+ });
+
+ let weak_state = cx.read(|_cx| state.downgrade());
+ let http: Arc<dyn HttpClient> = http_client;
+
+ let weak = weak_state.clone();
+ let http_clone = http.clone();
+ let result = cx
+ .spawn(async move |mut cx| get_fresh_credentials(&weak, &http_clone, &mut cx).await)
+ .await;
+
+ assert!(matches!(
+ result,
+ Err(LanguageModelCompletionError::NoApiKey { .. })
+ ));
+ }
+}