update semantic search to use keychain as fallback

KCaverly created

Change summary

crates/ai/src/embedding.rs                  | 45 ++++++++++++++++++----
crates/search/src/project_search.rs         |  2 
crates/semantic_index/src/semantic_index.rs | 20 +++++++++-
crates/zed/examples/semantic_index_eval.rs  | 18 ++++++++
4 files changed, 73 insertions(+), 12 deletions(-)

Detailed changes

crates/ai/src/embedding.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, ViewContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
@@ -20,9 +20,11 @@ use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::completion::OPENAI_API_URL;
 
 lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
@@ -87,6 +89,7 @@ impl Embedding {
 
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
+    pub api_key: Option<String>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+    pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
+        if self.api_key.is_none() {
+            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                Some(api_key)
+            } else if let Some((_, api_key)) = cx
+                .platform()
+                .read_credentials(OPENAI_API_URL)
+                .log_err()
+                .flatten()
+            {
+                String::from_utf8(api_key).log_err()
+            } else {
+                None
+            };
+
+            if let Some(api_key) = api_key {
+                self.api_key = Some(api_key);
+            }
+        }
+    }
+    pub fn new(
+        api_key: Option<String>,
+        client: Arc<dyn HttpClient>,
+        executor: Arc<Background>,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         OpenAIEmbeddings {
+            api_key,
             client,
             executor,
             rate_limit_count_rx,
@@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
     fn is_authenticated(&self) -> bool {
-        OPENAI_API_KEY.as_ref().is_some()
+        self.api_key.is_some()
     }
+
     fn max_tokens_per_batch(&self) -> usize {
         50000
     }
@@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
+        let Some(api_key) = self.api_key.clone() else {
+            return Err(anyhow!("no open ai key provided"));
+        };
 
         let mut request_number = 0;
         let mut rate_limiting = false;
@@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
-                    api_key,
+                    &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,
                 )

crates/search/src/project_search.rs 🔗

@@ -352,7 +352,7 @@ impl View for ProjectSearchView {
                         major_text = Cow::Borrowed("Not Authenticated");
                         show_minor_text = false;
                         Some(
-                            "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables"
+                            "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables.\nIf this variable was set using the Assistant Panel, please restart Zed to Authenticate."
                                 .to_string(),
                         )
                     }

crates/semantic_index/src/semantic_index.rs 🔗

@@ -7,7 +7,10 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::{
+    completion::OPENAI_API_URL,
+    embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
+};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -55,6 +58,19 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
+    let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+        Some(api_key)
+    } else if let Some((_, api_key)) = cx
+        .platform()
+        .read_credentials(OPENAI_API_URL)
+        .log_err()
+        .flatten()
+    {
+        String::from_utf8(api_key).log_err()
+    } else {
+        None
+    };
+
     cx.subscribe_global::<WorkspaceCreated, _>({
         move |event, cx| {
             let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -88,7 +104,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+            Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
             language_registry,
             cx.clone(),
         )

crates/zed/examples/semantic_index_eval.rs 🔗

@@ -1,3 +1,4 @@
+use ai::completion::OPENAI_API_URL;
 use ai::embedding::OpenAIEmbeddings;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
@@ -17,6 +18,7 @@ use std::{cmp, env, fs};
 use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 use util::http::{self};
 use util::paths::EMBEDDINGS_DIR;
+use util::ResultExt;
 use zed::languages;
 
 #[derive(Deserialize, Clone, Serialize)]
@@ -469,12 +471,26 @@ fn main() {
             .join("embeddings_db");
 
         let languages = languages.clone();
+
+        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            Some(api_key)
+        } else if let Some((_, api_key)) = cx
+            .platform()
+            .read_credentials(OPENAI_API_URL)
+            .log_err()
+            .flatten()
+        {
+            String::from_utf8(api_key).log_err()
+        } else {
+            None
+        };
+
         let fs = fs.clone();
         cx.spawn(|mut cx| async move {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+                Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
             )