ensure OpenAIEmbeddingProvider is using the provider credentials

KCaverly created

Change summary

crates/ai/src/providers/open_ai/embedding.rs | 11 ++++++-----
crates/semantic_index/src/embedding_queue.rs |  2 +-
crates/semantic_index/src/semantic_index.rs  | 17 ++++++-----------
3 files changed, 13 insertions(+), 17 deletions(-)

Detailed changes

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -162,14 +162,15 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
     async fn embed_batch(
         &self,
         spans: Vec<String>,
-        _credential: ProviderCredential,
+        credential: ProviderCredential,
     ) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
+        let api_key = match credential {
+            ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
+            _ => Err(anyhow!("no api key provided")),
+        }?;
 
         let mut request_number = 0;
         let mut rate_limiting = false;
@@ -178,7 +179,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
-                    api_key,
+                    &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,
                 )

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
-    provider_credential: ProviderCredential,
+    pub provider_credential: ProviderCredential,
 }
 
 #[derive(Clone)]

crates/semantic_index/src/semantic_index.rs 🔗

@@ -281,15 +281,13 @@ impl SemanticIndex {
     }
 
     pub fn authenticate(&mut self, cx: &AppContext) -> bool {
-        let credential = self.provider_credential.clone();
-        match credential {
-            ProviderCredential::NoCredentials => {
-                let credential = self.embedding_provider.retrieve_credentials(cx);
-                self.provider_credential = credential;
-            }
-            _ => {}
-        }
+        let existing_credential = self.provider_credential.clone();
+        let credential = match existing_credential {
+            ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
+            _ => existing_credential,
+        };
 
+        self.provider_credential = credential.clone();
         self.embedding_queue.lock().set_credential(credential);
         self.is_authenticated()
     }
@@ -1020,14 +1018,11 @@ impl SemanticIndex {
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
         if !self.is_authenticated() {
-            println!("Authenticating");
             if !self.authenticate(cx) {
                 return Task::ready(Err(anyhow!("user is not authenticated")));
             }
         }
 
-        println!("SHOULD NOW BE AUTHENTICATED");
-
         if !self.projects.contains_key(&project.downgrade()) {
             let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                 project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {