Moved authentication for the semantic index into the EmbeddingProvider

Created by KCaverly

Change summary

crates/ai/src/auth.rs                             | 11 -
crates/ai/src/completion.rs                       | 18 ---
crates/ai/src/embedding.rs                        | 15 --
crates/ai/src/providers/open_ai/auth.rs           | 46 --------
crates/ai/src/providers/open_ai/completion.rs     | 78 +++++++++++---
crates/ai/src/providers/open_ai/embedding.rs      | 88 +++++++++++++---
crates/ai/src/providers/open_ai/mod.rs            |  3 
crates/ai/src/providers/open_ai/new.rs            | 11 ++
crates/ai/src/test.rs                             | 48 +++++---
crates/assistant/src/assistant_panel.rs           |  7 
crates/assistant/src/codegen.rs                   |  8 
crates/semantic_index/src/embedding_queue.rs      | 17 --
crates/semantic_index/src/semantic_index.rs       | 50 ++------
crates/semantic_index/src/semantic_index_tests.rs |  6 
14 files changed, 200 insertions(+), 206 deletions(-)

Detailed changes

crates/ai/src/auth.rs 🔗

@@ -8,17 +8,8 @@ pub enum ProviderCredential {
 }
 
 pub trait CredentialProvider: Send + Sync {
+    fn has_credentials(&self) -> bool;
     fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
     fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
     fn delete_credentials(&self, cx: &AppContext);
 }
-
-#[derive(Clone)]
-pub struct NullCredentialProvider;
-impl CredentialProvider for NullCredentialProvider {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
-        ProviderCredential::NotNeeded
-    }
-    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {}
-    fn delete_credentials(&self, cx: &AppContext) {}
-}

crates/ai/src/completion.rs 🔗

@@ -1,28 +1,14 @@
 use anyhow::Result;
 use futures::{future::BoxFuture, stream::BoxStream};
-use gpui::AppContext;
 
-use crate::{
-    auth::{CredentialProvider, ProviderCredential},
-    models::LanguageModel,
-};
+use crate::{auth::CredentialProvider, models::LanguageModel};
 
 pub trait CompletionRequest: Send + Sync {
     fn data(&self) -> serde_json::Result<String>;
 }
 
-pub trait CompletionProvider {
+pub trait CompletionProvider: CredentialProvider {
     fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
-    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
-        self.credential_provider().retrieve_credentials(cx)
-    }
-    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
-        self.credential_provider().save_credentials(cx, credential);
-    }
-    fn delete_credentials(&self, cx: &AppContext) {
-        self.credential_provider().delete_credentials(cx);
-    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,

crates/ai/src/embedding.rs 🔗

@@ -2,12 +2,11 @@ use std::time::Instant;
 
 use anyhow::Result;
 use async_trait::async_trait;
-use gpui::AppContext;
 use ordered_float::OrderedFloat;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 
-use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::auth::CredentialProvider;
 use crate::models::LanguageModel;
 
 #[derive(Debug, PartialEq, Clone)]
@@ -70,17 +69,9 @@ impl Embedding {
 }
 
 #[async_trait]
-pub trait EmbeddingProvider: Sync + Send {
+pub trait EmbeddingProvider: CredentialProvider {
     fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
-    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
-        self.credential_provider().retrieve_credentials(cx)
-    }
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        credential: ProviderCredential,
-    ) -> Result<Vec<Embedding>>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }

crates/ai/src/providers/open_ai/auth.rs 🔗

@@ -1,46 +0,0 @@
-use std::env;
-
-use gpui::AppContext;
-use util::ResultExt;
-
-use crate::auth::{CredentialProvider, ProviderCredential};
-use crate::providers::open_ai::OPENAI_API_URL;
-
-#[derive(Clone)]
-pub struct OpenAICredentialProvider {}
-
-impl CredentialProvider for OpenAICredentialProvider {
-    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
-        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        };
-
-        if let Some(api_key) = api_key {
-            ProviderCredential::Credentials { api_key }
-        } else {
-            ProviderCredential::NoCredentials
-        }
-    }
-    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
-        match credential {
-            ProviderCredential::Credentials { api_key } => {
-                cx.platform()
-                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
-                    .log_err();
-            }
-            _ => {}
-        }
-    }
-    fn delete_credentials(&self, cx: &AppContext) {
-        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
-    }
-}

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -3,14 +3,17 @@ use futures::{
     future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
     Stream, StreamExt,
 };
-use gpui::executor::Background;
+use gpui::{executor::Background, AppContext};
 use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
 use serde::{Deserialize, Serialize};
 use std::{
+    env,
     fmt::{self, Display},
     io,
     sync::Arc,
 };
+use util::ResultExt;
 
 use crate::{
     auth::{CredentialProvider, ProviderCredential},
@@ -18,9 +21,7 @@ use crate::{
     models::LanguageModel,
 };
 
-use super::{auth::OpenAICredentialProvider, OpenAILanguageModel};
-
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
 
 #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(rename_all = "lowercase")]
@@ -194,42 +195,83 @@ pub async fn stream_completion(
 
 pub struct OpenAICompletionProvider {
     model: OpenAILanguageModel,
-    credential_provider: OpenAICredentialProvider,
-    credential: ProviderCredential,
+    credential: Arc<RwLock<ProviderCredential>>,
     executor: Arc<Background>,
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(
-        model_name: &str,
-        credential: ProviderCredential,
-        executor: Arc<Background>,
-    ) -> Self {
+    pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
         let model = OpenAILanguageModel::load(model_name);
-        let credential_provider = OpenAICredentialProvider {};
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
             model,
-            credential_provider,
             credential,
             executor,
         }
     }
 }
 
+impl CredentialProvider for OpenAICompletionProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
 impl CompletionProvider for OpenAICompletionProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
     }
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        let provider: Box<dyn CredentialProvider> = Box::new(self.credential_provider.clone());
-        provider
-    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let credential = self.credential.clone();
+        let credential = self.credential.read().clone();
         let request = stream_completion(credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -2,27 +2,29 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, AppContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
-use parking_lot::Mutex;
+use parking_lot::{Mutex, RwLock};
 use parse_duration::parse;
 use postage::watch;
 use serde::{Deserialize, Serialize};
+use std::env;
 use std::ops::Add;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
+use util::ResultExt;
 
 use crate::auth::{CredentialProvider, ProviderCredential};
 use crate::embedding::{Embedding, EmbeddingProvider};
 use crate::models::LanguageModel;
 use crate::providers::open_ai::OpenAILanguageModel;
 
-use crate::providers::open_ai::auth::OpenAICredentialProvider;
+use crate::providers::open_ai::OPENAI_API_URL;
 
 lazy_static! {
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
@@ -31,7 +33,7 @@ lazy_static! {
 #[derive(Clone)]
 pub struct OpenAIEmbeddingProvider {
     model: OpenAILanguageModel,
-    credential_provider: OpenAICredentialProvider,
+    credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -69,10 +71,11 @@ impl OpenAIEmbeddingProvider {
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAIEmbeddingProvider {
             model,
-            credential_provider: OpenAICredentialProvider {},
+            credential,
             client,
             executor,
             rate_limit_count_rx,
@@ -80,6 +83,13 @@ impl OpenAIEmbeddingProvider {
         }
     }
 
+    fn get_api_key(&self) -> Result<String> {
+        match self.credential.read().clone() {
+            ProviderCredential::Credentials { api_key } => Ok(api_key),
+            _ => Err(anyhow!("api credentials not provided")),
+        }
+    }
+
     fn resolve_rate_limit(&self) {
         let reset_time = *self.rate_limit_count_tx.lock().borrow();
 
@@ -136,6 +146,57 @@ impl OpenAIEmbeddingProvider {
     }
 }
 
+impl CredentialProvider for OpenAIEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddingProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
@@ -143,12 +204,6 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
         model
     }
 
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        let credential_provider: Box<dyn CredentialProvider> =
-            Box::new(self.credential_provider.clone());
-        credential_provider
-    }
-
     fn max_tokens_per_batch(&self) -> usize {
         50000
     }
@@ -157,18 +212,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
         *self.rate_limit_count_rx.borrow()
     }
 
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        credential: ProviderCredential,
-    ) -> Result<Vec<Embedding>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let api_key = match credential {
-            ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
-            _ => Err(anyhow!("no api key provided")),
-        }?;
+        let api_key = self.get_api_key()?;
 
         let mut request_number = 0;
         let mut rate_limiting = false;

crates/ai/src/providers/open_ai/mod.rs 🔗

@@ -1,4 +1,3 @@
-pub mod auth;
 pub mod completion;
 pub mod embedding;
 pub mod model;
@@ -6,3 +5,5 @@ pub mod model;
 pub use completion::*;
 pub use embedding::*;
 pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

crates/ai/src/providers/open_ai/new.rs 🔗

@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}

crates/ai/src/test.rs 🔗

@@ -5,10 +5,11 @@ use std::{
 
 use async_trait::async_trait;
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::AppContext;
 use parking_lot::Mutex;
 
 use crate::{
-    auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
+    auth::{CredentialProvider, ProviderCredential},
     completion::{CompletionProvider, CompletionRequest},
     embedding::{Embedding, EmbeddingProvider},
     models::{LanguageModel, TruncationDirection},
@@ -52,14 +53,12 @@ impl LanguageModel for FakeLanguageModel {
 
 pub struct FakeEmbeddingProvider {
     pub embedding_count: AtomicUsize,
-    pub credential_provider: NullCredentialProvider,
 }
 
 impl Clone for FakeEmbeddingProvider {
     fn clone(&self) -> Self {
         FakeEmbeddingProvider {
             embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
-            credential_provider: self.credential_provider.clone(),
         }
     }
 }
@@ -68,7 +67,6 @@ impl Default for FakeEmbeddingProvider {
     fn default() -> Self {
         FakeEmbeddingProvider {
             embedding_count: AtomicUsize::default(),
-            credential_provider: NullCredentialProvider {},
         }
     }
 }
@@ -99,16 +97,22 @@ impl FakeEmbeddingProvider {
     }
 }
 
+impl CredentialProvider for FakeEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+    fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         Box::new(FakeLanguageModel { capacity: 1000 })
     }
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        let credential_provider: Box<dyn CredentialProvider> =
-            Box::new(self.credential_provider.clone());
-        credential_provider
-    }
     fn max_tokens_per_batch(&self) -> usize {
         1000
     }
@@ -117,11 +121,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         None
     }
 
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _credential: ProviderCredential,
-    ) -> anyhow::Result<Vec<Embedding>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
 
@@ -129,11 +129,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
     }
 }
 
-pub struct TestCompletionProvider {
+pub struct FakeCompletionProvider {
     last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
 }
 
-impl TestCompletionProvider {
+impl FakeCompletionProvider {
     pub fn new() -> Self {
         Self {
             last_completion_tx: Mutex::new(None),
@@ -150,14 +150,22 @@ impl TestCompletionProvider {
     }
 }
 
-impl CompletionProvider for TestCompletionProvider {
+impl CredentialProvider for FakeCompletionProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+    fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
     fn base_model(&self) -> Box<dyn LanguageModel> {
         let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
         model
     }
-    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
-        Box::new(NullCredentialProvider {})
-    }
     fn complete(
         &self,
         _prompt: Box<dyn CompletionRequest>,

crates/assistant/src/assistant_panel.rs 🔗

@@ -10,7 +10,7 @@ use ai::{
     auth::ProviderCredential,
     completion::{CompletionProvider, CompletionRequest},
     providers::open_ai::{
-        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage,
     },
 };
 
@@ -48,7 +48,7 @@ use semantic_index::{SemanticIndex, SemanticIndexStatus};
 use settings::SettingsStore;
 use std::{
     cell::{Cell, RefCell},
-    cmp, env,
+    cmp,
     fmt::Write,
     iter,
     ops::Range,
@@ -210,7 +210,6 @@ impl AssistantPanel {
                     // Defaulting currently to GPT4, allow for this to be set via config.
                     let completion_provider = Box::new(OpenAICompletionProvider::new(
                         "gpt-4",
-                        ProviderCredential::NoCredentials,
                         cx.background().clone(),
                     ));
 
@@ -298,7 +297,6 @@ impl AssistantPanel {
         cx: &mut ViewContext<Self>,
         project: &ModelHandle<Project>,
     ) {
-        let credential = self.credential.borrow().clone();
         let selection = editor.read(cx).selections.newest_anchor().clone();
         if selection.start.excerpt_id() != selection.end.excerpt_id() {
             return;
@@ -330,7 +328,6 @@ impl AssistantPanel {
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let provider = Arc::new(OpenAICompletionProvider::new(
             "gpt-4",
-            credential,
             cx.background().clone(),
         ));
 

crates/assistant/src/codegen.rs 🔗

@@ -335,7 +335,7 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use ai::test::TestCompletionProvider;
+    use ai::test::FakeCompletionProvider;
     use futures::stream::{self};
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;
@@ -379,7 +379,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -445,7 +445,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -511,7 +511,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{parsing::Span, JobHandle};
-use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
+use ai::embedding::EmbeddingProvider;
 use gpui::executor::Background;
 use parking_lot::Mutex;
 use smol::channel;
@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
-    pub provider_credential: ProviderCredential,
 }
 
 #[derive(Clone)]
@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
 }
 
 impl EmbeddingQueue {
-    pub fn new(
-        embedding_provider: Arc<dyn EmbeddingProvider>,
-        executor: Arc<Background>,
-        provider_credential: ProviderCredential,
-    ) -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
@@ -64,14 +59,9 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
-            provider_credential,
         }
     }
 
-    pub fn set_credential(&mut self, credential: ProviderCredential) {
-        self.provider_credential = credential;
-    }
-
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
@@ -118,7 +108,6 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
-        let credential = self.provider_credential.clone();
 
         self.executor
             .spawn(async move {
@@ -143,7 +132,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans, credential).await {
+                match embedding_provider.embed_batch(spans).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {

crates/semantic_index/src/semantic_index.rs 🔗

@@ -7,7 +7,6 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::auth::ProviderCredential;
 use ai::embedding::{Embedding, EmbeddingProvider};
 use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
@@ -125,8 +124,6 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
-    provider_credential: ProviderCredential,
-    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
 struct ProjectState {
@@ -281,24 +278,17 @@ impl SemanticIndex {
     }
 
     pub fn authenticate(&mut self, cx: &AppContext) -> bool {
-        let existing_credential = self.provider_credential.clone();
-        let credential = match existing_credential {
-            ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
-            _ => existing_credential,
-        };
+        if !self.embedding_provider.has_credentials() {
+            self.embedding_provider.retrieve_credentials(cx);
+        } else {
+            return true;
+        }
 
-        self.provider_credential = credential.clone();
-        self.embedding_queue.lock().set_credential(credential);
-        self.is_authenticated()
+        self.embedding_provider.has_credentials()
     }
 
     pub fn is_authenticated(&self) -> bool {
-        let credential = &self.provider_credential;
-        match credential {
-            &ProviderCredential::Credentials { .. } => true,
-            &ProviderCredential::NotNeeded => true,
-            _ => false,
-        }
+        self.embedding_provider.has_credentials()
     }
 
     pub fn enabled(cx: &AppContext) -> bool {
@@ -348,7 +338,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -413,8 +403,6 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
-                provider_credential: ProviderCredential::NoCredentials,
-                embedding_queue
             }
         }))
     }
@@ -729,14 +717,13 @@ impl SemanticIndex {
 
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
-        let credential = self.provider_credential.clone();
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
 
             let query = embedding_provider
-                .embed_batch(vec![query], credential)
+                .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -954,7 +941,6 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
-        let credential = self.provider_credential.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -969,15 +955,10 @@ impl SemanticIndex {
                     .parse_file_with_template(None, &snapshot.text(), language)
                     .log_err()
                     .unwrap_or_default();
-                if Self::embed_spans(
-                    &mut spans,
-                    embedding_provider.as_ref(),
-                    &db,
-                    credential.clone(),
-                )
-                .await
-                .log_err()
-                .is_some()
+                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                    .await
+                    .log_err()
+                    .is_some()
                 {
                     for span in spans {
                         let similarity = span.embedding.unwrap().similarity(&query);
@@ -1201,7 +1182,6 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
-        credential: ProviderCredential,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1224,7 +1204,7 @@ impl SemanticIndex {
 
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch), credential.clone())
+                    .embed_batch(mem::take(&mut batch))
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1236,7 +1216,7 @@ impl SemanticIndex {
 
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch), credential)
+                .embed_batch(mem::take(&mut batch))
                 .await?;
 
             embeddings.extend(batch_embeddings);

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 
-    let mut queue = EmbeddingQueue::new(
-        embedding_provider.clone(),
-        cx.background(),
-        ai::auth::ProviderCredential::NoCredentials,
-    );
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
     for file in &files {
         queue.push(file.clone());
     }