Updated authentication for embedding provider

Created by KCaverly

Change summary

crates/ai/Cargo.toml                              |   3 
crates/ai/src/ai.rs                               |   3 
crates/ai/src/auth.rs                             |  20 ++
crates/ai/src/embedding.rs                        |   8 
crates/ai/src/prompts/base.rs                     |  41 ----
crates/ai/src/providers/dummy.rs                  |  85 -----------
crates/ai/src/providers/mod.rs                    |   1 
crates/ai/src/providers/open_ai/auth.rs           |  33 ++++
crates/ai/src/providers/open_ai/embedding.rs      |  46 +----
crates/ai/src/providers/open_ai/mod.rs            |   1 
crates/ai/src/test.rs                             | 123 +++++++++++++++++
crates/assistant/src/codegen.rs                   |  14 +
crates/semantic_index/Cargo.toml                  |   1 
crates/semantic_index/src/embedding_queue.rs      |  16 +-
crates/semantic_index/src/semantic_index.rs       |  52 ++++--
crates/semantic_index/src/semantic_index_tests.rs | 101 ++-----------
16 files changed, 277 insertions(+), 271 deletions(-)

Detailed changes

crates/ai/Cargo.toml 🔗

@@ -8,6 +8,9 @@ publish = false
 path = "src/ai.rs"
 doctest = false
 
+[features]
+test-support = []
+
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }

crates/ai/src/ai.rs 🔗

@@ -1,5 +1,8 @@
+pub mod auth;
 pub mod completion;
 pub mod embedding;
 pub mod models;
 pub mod prompts;
 pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;

crates/ai/src/auth.rs 🔗

@@ -0,0 +1,20 @@
+use gpui::AppContext;
+
+#[derive(Clone)]
+pub enum ProviderCredential {
+    Credentials { api_key: String },
+    NoCredentials,
+    NotNeeded,
+}
+
+pub trait CredentialProvider: Send + Sync {
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+}
+
+#[derive(Clone)]
+pub struct NullCredentialProvider;
+impl CredentialProvider for NullCredentialProvider {
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+}

crates/ai/src/embedding.rs 🔗

@@ -7,6 +7,7 @@ use ordered_float::OrderedFloat;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 
+use crate::auth::{CredentialProvider, ProviderCredential};
 use crate::models::LanguageModel;
 
 #[derive(Debug, PartialEq, Clone)]
@@ -71,11 +72,14 @@ impl Embedding {
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
+    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        self.credential_provider().retrieve_credentials(cx)
+    }
     async fn embed_batch(
         &self,
         spans: Vec<String>,
-        api_key: Option<String>,
+        credential: ProviderCredential,
     ) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn rate_limit_expiration(&self) -> Option<Instant>;

crates/ai/src/prompts/base.rs 🔗

@@ -126,6 +126,7 @@ impl PromptChain {
 #[cfg(test)]
 pub(crate) mod tests {
     use crate::models::TruncationDirection;
+    use crate::test::FakeLanguageModel;
 
     use super::*;
 
@@ -181,39 +182,7 @@ pub(crate) mod tests {
             }
         }
 
-        #[derive(Clone)]
-        struct DummyLanguageModel {
-            capacity: usize,
-        }
-
-        impl LanguageModel for DummyLanguageModel {
-            fn name(&self) -> String {
-                "dummy".to_string()
-            }
-            fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-                anyhow::Ok(content.chars().collect::<Vec<char>>().len())
-            }
-            fn truncate(
-                &self,
-                content: &str,
-                length: usize,
-                direction: TruncationDirection,
-            ) -> anyhow::Result<String> {
-                anyhow::Ok(match direction {
-                    TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
-                        .into_iter()
-                        .collect::<String>(),
-                    TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
-                        .into_iter()
-                        .collect::<String>(),
-                })
-            }
-            fn capacity(&self) -> anyhow::Result<usize> {
-                anyhow::Ok(self.capacity)
-            }
-        }
-
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -249,7 +218,7 @@ pub(crate) mod tests {
 
         // Testing with Truncation Off
         // Should ignore capacity and return all prompts
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -286,7 +255,7 @@ pub(crate) mod tests {
         // Testing with Truncation Off
         // Should ignore capacity and return all prompts
         let capacity = 20;
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -322,7 +291,7 @@ pub(crate) mod tests {
         // Change Ordering of Prompts Based on Priority
         let capacity = 120;
         let reserved_tokens = 10;
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,

crates/ai/src/providers/dummy.rs 🔗

@@ -1,85 +0,0 @@
-use std::time::Instant;
-
-use crate::{
-    completion::CompletionRequest,
-    embedding::{Embedding, EmbeddingProvider},
-    models::{LanguageModel, TruncationDirection},
-};
-use async_trait::async_trait;
-use gpui::AppContext;
-use serde::Serialize;
-
-pub struct DummyLanguageModel {}
-
-impl LanguageModel for DummyLanguageModel {
-    fn name(&self) -> String {
-        "dummy".to_string()
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(1000)
-    }
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: crate::models::TruncationDirection,
-    ) -> anyhow::Result<String> {
-        if content.len() < length {
-            return anyhow::Ok(content.to_string());
-        }
-
-        let truncated = match direction {
-            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
-                .iter()
-                .collect::<String>(),
-            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
-                .iter()
-                .collect::<String>(),
-        };
-
-        anyhow::Ok(truncated)
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
-    }
-}
-
-#[derive(Serialize)]
-pub struct DummyCompletionRequest {
-    pub name: String,
-}
-
-impl CompletionRequest for DummyCompletionRequest {
-    fn data(&self) -> serde_json::Result<String> {
-        serde_json::to_string(self)
-    }
-}
-
-pub struct DummyEmbeddingProvider {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddingProvider {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
-        Some("Dummy Credentials".to_string())
-    }
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        Box::new(DummyLanguageModel {})
-    }
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        api_key: Option<String>,
-    ) -> anyhow::Result<Vec<Embedding>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
-        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
-        return Ok(vec![dummy_vec; spans.len()]);
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        8190
-    }
-}

crates/ai/src/providers/open_ai/auth.rs 🔗

@@ -0,0 +1,33 @@
+use std::env;
+
+use gpui::AppContext;
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::providers::open_ai::OPENAI_API_URL;
+
+#[derive(Clone)]
+pub struct OpenAICredentialProvider {}
+
+impl CredentialProvider for OpenAICredentialProvider {
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            Some(api_key)
+        } else if let Some((_, api_key)) = cx
+            .platform()
+            .read_credentials(OPENAI_API_URL)
+            .log_err()
+            .flatten()
+        {
+            String::from_utf8(api_key).log_err()
+        } else {
+            None
+        };
+
+        if let Some(api_key) = api_key {
+            ProviderCredential::Credentials { api_key }
+        } else {
+            ProviderCredential::NoCredentials
+        }
+    }
+}

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::{serde_json, AppContext};
+use gpui::serde_json;
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
@@ -17,13 +17,13 @@ use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
-use util::ResultExt;
 
+use crate::auth::{CredentialProvider, ProviderCredential};
 use crate::embedding::{Embedding, EmbeddingProvider};
 use crate::models::LanguageModel;
 use crate::providers::open_ai::OpenAILanguageModel;
 
-use super::OPENAI_API_URL;
+use crate::providers::open_ai::auth::OpenAICredentialProvider;
 
 lazy_static! {
     static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
@@ -33,6 +33,7 @@ lazy_static! {
 #[derive(Clone)]
 pub struct OpenAIEmbeddingProvider {
     model: OpenAILanguageModel,
+    credential_provider: OpenAICredentialProvider,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -73,6 +74,7 @@ impl OpenAIEmbeddingProvider {
 
         OpenAIEmbeddingProvider {
             model,
+            credential_provider: OpenAICredentialProvider {},
             client,
             executor,
             rate_limit_count_rx,
@@ -138,25 +140,17 @@ impl OpenAIEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddingProvider {
-    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
-        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        };
-        api_key
-    }
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
     }
+
+    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
+        let credential_provider: Box<dyn CredentialProvider> =
+            Box::new(self.credential_provider.clone());
+        credential_provider
+    }
+
     fn max_tokens_per_batch(&self) -> usize {
         50000
     }
@@ -164,25 +158,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
     fn rate_limit_expiration(&self) -> Option<Instant> {
         *self.rate_limit_count_rx.borrow()
     }
-    // fn truncate(&self, span: &str) -> (String, usize) {
-    //     let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-    //     let output = if tokens.len() > OPENAI_INPUT_LIMIT {
-    //         tokens.truncate(OPENAI_INPUT_LIMIT);
-    //         OPENAI_BPE_TOKENIZER
-    //             .decode(tokens.clone())
-    //             .ok()
-    //             .unwrap_or_else(|| span.to_string())
-    //     } else {
-    //         span.to_string()
-    //     };
-
-    //     (output, tokens.len())
-    // }
 
     async fn embed_batch(
         &self,
         spans: Vec<String>,
-        api_key: Option<String>,
+        _credential: ProviderCredential,
     ) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;

crates/ai/src/test.rs 🔗

@@ -0,0 +1,123 @@
+use std::{
+    sync::atomic::{self, AtomicUsize, Ordering},
+    time::Instant,
+};
+
+use async_trait::async_trait;
+
+use crate::{
+    auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
+    embedding::{Embedding, EmbeddingProvider},
+    models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+    pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+    fn name(&self) -> String {
+        "dummy".to_string()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        anyhow::Ok(match direction {
+            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+                .into_iter()
+                .collect::<String>(),
+            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+                .into_iter()
+                .collect::<String>(),
+        })
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(self.capacity)
+    }
+}
+
+pub struct FakeEmbeddingProvider {
+    pub embedding_count: AtomicUsize,
+    pub credential_provider: NullCredentialProvider,
+}
+
+impl Clone for FakeEmbeddingProvider {
+    fn clone(&self) -> Self {
+        FakeEmbeddingProvider {
+            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+            credential_provider: self.credential_provider.clone(),
+        }
+    }
+}
+
+impl Default for FakeEmbeddingProvider {
+    fn default() -> Self {
+        FakeEmbeddingProvider {
+            embedding_count: AtomicUsize::default(),
+            credential_provider: NullCredentialProvider {},
+        }
+    }
+}
+
+impl FakeEmbeddingProvider {
+    pub fn embedding_count(&self) -> usize {
+        self.embedding_count.load(atomic::Ordering::SeqCst)
+    }
+
+    pub fn embed_sync(&self, span: &str) -> Embedding {
+        let mut result = vec![1.0; 26];
+        for letter in span.chars() {
+            let letter = letter.to_ascii_lowercase();
+            if letter as u32 >= 'a' as u32 {
+                let ix = (letter as u32) - ('a' as u32);
+                if ix < 26 {
+                    result[ix as usize] += 1.0;
+                }
+            }
+        }
+
+        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+        for x in &mut result {
+            *x /= norm;
+        }
+
+        result.into()
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(FakeLanguageModel { capacity: 1000 })
+    }
+    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
+        let credential_provider: Box<dyn CredentialProvider> =
+            Box::new(self.credential_provider.clone());
+        credential_provider
+    }
+    fn max_tokens_per_batch(&self) -> usize {
+        1000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _credential: ProviderCredential,
+    ) -> anyhow::Result<Vec<Embedding>> {
+        self.embedding_count
+            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+    }
+}

crates/assistant/src/codegen.rs 🔗

@@ -335,7 +335,6 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use ai::providers::dummy::DummyCompletionRequest;
     use futures::{
         future::BoxFuture,
         stream::{self, BoxStream},
@@ -345,9 +344,21 @@ mod tests {
     use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
     use parking_lot::Mutex;
     use rand::prelude::*;
+    use serde::Serialize;
     use settings::SettingsStore;
     use smol::future::FutureExt;
 
+    #[derive(Serialize)]
+    pub struct DummyCompletionRequest {
+        pub name: String,
+    }
+
+    impl CompletionRequest for DummyCompletionRequest {
+        fn data(&self) -> serde_json::Result<String> {
+            serde_json::to_string(self)
+        }
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(
         cx: &mut TestAppContext,
@@ -381,6 +392,7 @@ mod tests {
                 cx,
             )
         });
+
         let request = Box::new(DummyCompletionRequest {
             name: "test".to_string(),
         });

crates/semantic_index/Cargo.toml 🔗

@@ -42,6 +42,7 @@ sha1 = "0.10.5"
 ndarray = { version = "0.15.0" }
 
 [dev-dependencies]
+ai = { path = "../ai", features = ["test-support"] }
 collections = { path = "../collections", features = ["test-support"] }
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{parsing::Span, JobHandle};
-use ai::embedding::EmbeddingProvider;
+use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
 use gpui::executor::Background;
 use parking_lot::Mutex;
 use smol::channel;
@@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
-    api_key: Option<String>,
+    provider_credential: ProviderCredential,
 }
 
 #[derive(Clone)]
@@ -54,7 +54,7 @@ impl EmbeddingQueue {
     pub fn new(
         embedding_provider: Arc<dyn EmbeddingProvider>,
         executor: Arc<Background>,
-        api_key: Option<String>,
+        provider_credential: ProviderCredential,
     ) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
@@ -64,12 +64,12 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
-            api_key,
+            provider_credential,
         }
     }
 
-    pub fn set_api_key(&mut self, api_key: Option<String>) {
-        self.api_key = api_key
+    pub fn set_credential(&mut self, credential: ProviderCredential) {
+        self.provider_credential = credential
     }
 
     pub fn push(&mut self, file: FileToEmbed) {
@@ -118,7 +118,7 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();
+        let credential = self.provider_credential.clone();
 
         self.executor
             .spawn(async move {
@@ -143,7 +143,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans, api_key).await {
+                match embedding_provider.embed_batch(spans, credential).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {

crates/semantic_index/src/semantic_index.rs 🔗

@@ -7,6 +7,7 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
+use ai::auth::ProviderCredential;
 use ai::embedding::{Embedding, EmbeddingProvider};
 use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
@@ -124,7 +125,7 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
-    api_key: Option<String>,
+    provider_credential: ProviderCredential,
     embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
@@ -279,18 +280,27 @@ impl SemanticIndex {
         }
     }
 
-    pub fn authenticate(&mut self, cx: &AppContext) {
-        if self.api_key.is_none() {
-            self.api_key = self.embedding_provider.retrieve_credentials(cx);
-
-            self.embedding_queue
-                .lock()
-                .set_api_key(self.api_key.clone());
+    pub fn authenticate(&mut self, cx: &AppContext) -> bool {
+        let credential = self.provider_credential.clone();
+        match credential {
+            ProviderCredential::NoCredentials => {
+                let credential = self.embedding_provider.retrieve_credentials(cx);
+                self.provider_credential = credential;
+            }
+            _ => {}
         }
+
+        // Use the (possibly refreshed) credential, not the stale pre-retrieval clone,
+        // so the embedding queue is not left holding `NoCredentials` after a
+        // successful retrieval.
+        self.embedding_queue
+            .lock()
+            .set_credential(self.provider_credential.clone());
+
+        self.is_authenticated()
     }
 
     pub fn is_authenticated(&self) -> bool {
-        self.api_key.is_some()
+        let credential = &self.provider_credential;
+        match credential {
+            &ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
     }
 
     pub fn enabled(cx: &AppContext) -> bool {
@@ -340,7 +350,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -405,7 +415,7 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
-                api_key: None,
+                provider_credential: ProviderCredential::NoCredentials,
                 embedding_queue
             }
         }))
@@ -721,13 +731,14 @@ impl SemanticIndex {
 
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();
+        let credential = self.provider_credential.clone();
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
+
             let query = embedding_provider
-                .embed_batch(vec![query], api_key)
+                .embed_batch(vec![query], credential)
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -945,7 +956,7 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
-        let api_key = self.api_key.clone();
+        let credential = self.provider_credential.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -964,7 +975,7 @@ impl SemanticIndex {
                     &mut spans,
                     embedding_provider.as_ref(),
                     &db,
-                    api_key.clone(),
+                    credential.clone(),
                 )
                 .await
                 .log_err()
@@ -1008,9 +1019,8 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if self.api_key.is_none() {
-            self.authenticate(cx);
-            if self.api_key.is_none() {
+        if !self.is_authenticated() {
+            if !self.authenticate(cx) {
                 return Task::ready(Err(anyhow!("user is not authenticated")));
             }
         }
@@ -1193,7 +1203,7 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
-        api_key: Option<String>,
+        credential: ProviderCredential,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1216,7 +1226,7 @@ impl SemanticIndex {
 
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch), api_key.clone())
+                    .embed_batch(mem::take(&mut batch), credential.clone())
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1228,7 +1238,7 @@ impl SemanticIndex {
 
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch), api_key)
+                .embed_batch(mem::take(&mut batch), credential)
                 .await?;
 
             embeddings.extend(batch_embeddings);

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -4,14 +4,9 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
-use ai::{
-    embedding::{Embedding, EmbeddingProvider},
-    models::LanguageModel,
-};
-use anyhow::Result;
-use async_trait::async_trait;
-use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
+use ai::test::FakeEmbeddingProvider;
+
+use gpui::{executor::Deterministic, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
@@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
 use settings::SettingsStore;
-use std::{
-    path::Path,
-    sync::{
-        atomic::{self, AtomicUsize},
-        Arc,
-    },
-    time::{Instant, SystemTime},
-};
+use std::{path::Path, sync::Arc, time::SystemTime};
 use unindent::Unindent;
 use util::RandomCharIter;
 
@@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
+    let mut queue = EmbeddingQueue::new(
+        embedding_provider.clone(),
+        cx.background(),
+        ai::auth::ProviderCredential::NoCredentials,
+    );
     for file in &files {
         queue.push(file.clone());
     }
@@ -284,7 +276,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -470,7 +462,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() {
     );
 }
 
-#[derive(Default)]
-struct FakeEmbeddingProvider {
-    embedding_count: AtomicUsize,
-}
-
-impl FakeEmbeddingProvider {
-    fn embedding_count(&self) -> usize {
-        self.embedding_count.load(atomic::Ordering::SeqCst)
-    }
-
-    fn embed_sync(&self, span: &str) -> Embedding {
-        let mut result = vec![1.0; 26];
-        for letter in span.chars() {
-            let letter = letter.to_ascii_lowercase();
-            if letter as u32 >= 'a' as u32 {
-                let ix = (letter as u32) - ('a' as u32);
-                if ix < 26 {
-                    result[ix as usize] += 1.0;
-                }
-            }
-        }
-
-        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-        for x in &mut result {
-            *x /= norm;
-        }
-
-        result.into()
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        Box::new(DummyLanguageModel {})
-    }
-    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
-        Some("Fake Credentials".to_string())
-    }
-    fn max_tokens_per_batch(&self) -> usize {
-        1000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        self.embedding_count
-            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-
-        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
-    }
-}
-
 fn js_lang() -> Arc<Language> {
     Arc::new(
         Language::new(