update semantic search to use keychain as fallback (#3151)

Kyle Caverly created

Use the keychain for authenticating as fallback when api_key is not
present in environment variables.

Release Notes:

- Add consistency between OPENAI_API_KEY management in Semantic Search
and Assistant

Change summary

crates/ai/src/embedding.rs                  | 45 ++++++++++++++++++----
crates/search/src/project_search.rs         | 25 ++++++------
crates/semantic_index/src/semantic_index.rs | 20 +++++++++-
crates/zed/examples/semantic_index_eval.rs  | 18 ++++++++
4 files changed, 85 insertions(+), 23 deletions(-)

Detailed changes

crates/ai/src/embedding.rs 🔗

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, ViewContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
@@ -20,9 +20,11 @@ use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::completion::OPENAI_API_URL;
 
 lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
@@ -87,6 +89,7 @@ impl Embedding {
 
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
+    pub api_key: Option<String>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+    pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
+        if self.api_key.is_none() {
+            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                Some(api_key)
+            } else if let Some((_, api_key)) = cx
+                .platform()
+                .read_credentials(OPENAI_API_URL)
+                .log_err()
+                .flatten()
+            {
+                String::from_utf8(api_key).log_err()
+            } else {
+                None
+            };
+
+            if let Some(api_key) = api_key {
+                self.api_key = Some(api_key);
+            }
+        }
+    }
+    pub fn new(
+        api_key: Option<String>,
+        client: Arc<dyn HttpClient>,
+        executor: Arc<Background>,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         OpenAIEmbeddings {
+            api_key,
             client,
             executor,
             rate_limit_count_rx,
@@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
     fn is_authenticated(&self) -> bool {
-        OPENAI_API_KEY.as_ref().is_some()
+        self.api_key.is_some()
     }
+
     fn max_tokens_per_batch(&self) -> usize {
         50000
     }
@@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
+        let Some(api_key) = self.api_key.clone() else {
+            return Err(anyhow!("no open ai key provided"));
+        };
 
         let mut request_number = 0;
         let mut rate_limiting = false;
@@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
-                    api_key,
+                    &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,
                 )

crates/search/src/project_search.rs 🔗

@@ -351,33 +351,32 @@ impl View for ProjectSearchView {
                     SemanticIndexStatus::NotAuthenticated => {
                         major_text = Cow::Borrowed("Not Authenticated");
                         show_minor_text = false;
-                        Some(
-                            "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables"
-                                .to_string(),
-                        )
+                        Some(vec![
+                            "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables."
+                                .to_string(), "If you authenticated using the Assistant Panel, please restart Zed to Authenticate.".to_string()])
                     }
-                    SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
+                    SemanticIndexStatus::Indexed => Some(vec!["Indexing complete".to_string()]),
                     SemanticIndexStatus::Indexing {
                         remaining_files,
                         rate_limit_expiry,
                     } => {
                         if remaining_files == 0 {
-                            Some(format!("Indexing..."))
+                            Some(vec![format!("Indexing...")])
                         } else {
                             if let Some(rate_limit_expiry) = rate_limit_expiry {
                                 let remaining_seconds =
                                     rate_limit_expiry.duration_since(Instant::now());
                                 if remaining_seconds > Duration::from_secs(0) {
-                                    Some(format!(
+                                    Some(vec![format!(
                                         "Remaining files to index (rate limit resets in {}s): {}",
                                         remaining_seconds.as_secs(),
                                         remaining_files
-                                    ))
+                                    )])
                                 } else {
-                                    Some(format!("Remaining files to index: {}", remaining_files))
+                                    Some(vec![format!("Remaining files to index: {}", remaining_files)])
                                 }
                             } else {
-                                Some(format!("Remaining files to index: {}", remaining_files))
+                                Some(vec![format!("Remaining files to index: {}", remaining_files)])
                             }
                         }
                     }
@@ -394,9 +393,11 @@ impl View for ProjectSearchView {
             } else {
                 match current_mode {
                     SearchMode::Semantic => {
-                        let mut minor_text = Vec::new();
+                        let mut minor_text: Vec<String> = Vec::new();
                         minor_text.push("".into());
-                        minor_text.extend(semantic_status);
+                        if let Some(semantic_status) = semantic_status {
+                            minor_text.extend(semantic_status);
+                        }
                         if show_minor_text {
                             minor_text
                                 .push("Simply explain the code you are looking to find.".into());

crates/semantic_index/src/semantic_index.rs 🔗

@@ -7,7 +7,10 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::{
+    completion::OPENAI_API_URL,
+    embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
+};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -55,6 +58,19 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
+    let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+        Some(api_key)
+    } else if let Some((_, api_key)) = cx
+        .platform()
+        .read_credentials(OPENAI_API_URL)
+        .log_err()
+        .flatten()
+    {
+        String::from_utf8(api_key).log_err()
+    } else {
+        None
+    };
+
     cx.subscribe_global::<WorkspaceCreated, _>({
         move |event, cx| {
             let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -88,7 +104,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+            Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
             language_registry,
             cx.clone(),
         )

crates/zed/examples/semantic_index_eval.rs 🔗

@@ -1,3 +1,4 @@
+use ai::completion::OPENAI_API_URL;
 use ai::embedding::OpenAIEmbeddings;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
@@ -17,6 +18,7 @@ use std::{cmp, env, fs};
 use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 use util::http::{self};
 use util::paths::EMBEDDINGS_DIR;
+use util::ResultExt;
 use zed::languages;
 
 #[derive(Deserialize, Clone, Serialize)]
@@ -469,12 +471,26 @@ fn main() {
             .join("embeddings_db");
 
         let languages = languages.clone();
+
+        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            Some(api_key)
+        } else if let Some((_, api_key)) = cx
+            .platform()
+            .read_credentials(OPENAI_API_URL)
+            .log_err()
+            .flatten()
+        {
+            String::from_utf8(api_key).log_err()
+        } else {
+            None
+        };
+
         let fs = fs.clone();
         cx.spawn(|mut cx| async move {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+                Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
             )