Move keychain access into the semantic index rather than on init (#3158)

Created by Kyle Caverly

Remove the keychain request during init

Release Notes:

- Move the keychain request inside indexing (see the sketch below).
- Make `install_default_formatters` a no-op during tests.
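
The reworked `EmbeddingProvider` trait replaces `is_authenticated` with `retrieve_credentials(cx)` and takes the API key as an explicit argument to `embed_batch`. The sketch below is a minimal, standalone illustration of that shape only: it uses stand-in types (`FakeContext`, `DummyProvider`) instead of the real gpui `AppContext` and `Embedding`, and drops the `async_trait` that the real trait uses.

```rust
use std::error::Error;

// Stand-in for gpui::AppContext; illustrative only.
struct FakeContext {
    stored_key: Option<String>,
}

// Simplified, synchronous mirror of the reworked EmbeddingProvider trait:
// credentials are looked up lazily and passed explicitly into embed_batch.
trait EmbeddingProvider {
    fn retrieve_credentials(&self, cx: &FakeContext) -> Option<String>;
    fn embed_batch(
        &self,
        spans: Vec<String>,
        api_key: Option<String>,
    ) -> Result<Vec<Vec<f32>>, Box<dyn Error>>;
}

struct DummyProvider;

impl EmbeddingProvider for DummyProvider {
    fn retrieve_credentials(&self, cx: &FakeContext) -> Option<String> {
        // Prefer the environment variable, then fall back to the stored
        // credential, mirroring the OPENAI_API_KEY / keychain order used
        // by OpenAIEmbeddings.
        std::env::var("OPENAI_API_KEY")
            .ok()
            .or_else(|| cx.stored_key.clone())
    }

    fn embed_batch(
        &self,
        spans: Vec<String>,
        api_key: Option<String>,
    ) -> Result<Vec<Vec<f32>>, Box<dyn Error>> {
        // Without a key the call fails, just like the real provider.
        if api_key.is_none() {
            return Err("no api key provided".into());
        }
        // Return fixed-size dummy embeddings (1536 dims, as in DummyEmbeddings).
        Ok(spans.iter().map(|_| vec![0.0f32; 1536]).collect())
    }
}

fn main() -> Result<(), Box<dyn Error>> {
    let cx = FakeContext {
        stored_key: Some("key-from-keychain".into()),
    };
    let provider = DummyProvider;

    // The key is only read when indexing actually starts, not at init.
    let api_key = provider.retrieve_credentials(&cx);
    let embeddings = provider.embed_batch(vec!["fn main() {}".into()], api_key)?;
    println!("embedded {} span(s)", embeddings.len());
    Ok(())
}
```

In the real crate, `SemanticIndex::authenticate` performs this lookup once when indexing starts and hands the key to the `EmbeddingQueue` via `set_api_key`, so the keychain is never touched during init.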

Change summary

crates/ai/src/embedding.rs                        | 71 +++++++--------
crates/project/src/project.rs                     | 31 ++++--
crates/semantic_index/src/embedding_queue.rs      | 15 +++
crates/semantic_index/src/semantic_index.rs       | 73 ++++++++++------
crates/semantic_index/src/semantic_index_tests.rs | 14 ++-
crates/zed/examples/semantic_index_eval.rs        | 19 ---
script/evaluate_semantic_index                    |  2 
7 files changed, 122 insertions(+), 103 deletions(-)

Detailed changes

crates/ai/src/embedding.rs

@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::executor::Background;
-use gpui::{serde_json, ViewContext};
+use gpui::{serde_json, AppContext};
 use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
@@ -89,7 +89,6 @@ impl Embedding {
 
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
-    pub api_key: Option<String>,
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
     rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -123,8 +122,12 @@ struct OpenAIEmbeddingUsage {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    fn is_authenticated(&self) -> bool;
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        api_key: Option<String>,
+    ) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn truncate(&self, span: &str) -> (String, usize);
     fn rate_limit_expiration(&self) -> Option<Instant>;
@@ -134,13 +137,17 @@ pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        true
+    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+        Some("Dummy API KEY".to_string())
     }
     fn rate_limit_expiration(&self) -> Option<Instant> {
         None
     }
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
         let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
@@ -169,36 +176,11 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
-        if self.api_key.is_none() {
-            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-                Some(api_key)
-            } else if let Some((_, api_key)) = cx
-                .platform()
-                .read_credentials(OPENAI_API_URL)
-                .log_err()
-                .flatten()
-            {
-                String::from_utf8(api_key).log_err()
-            } else {
-                None
-            };
-
-            if let Some(api_key) = api_key {
-                self.api_key = Some(api_key);
-            }
-        }
-    }
-    pub fn new(
-        api_key: Option<String>,
-        client: Arc<dyn HttpClient>,
-        executor: Arc<Background>,
-    ) -> Self {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
         OpenAIEmbeddings {
-            api_key,
             client,
             executor,
             rate_limit_count_rx,
@@ -264,8 +246,19 @@ impl OpenAIEmbeddings {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        self.api_key.is_some()
+    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
+        if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+            Some(api_key)
+        } else if let Some((_, api_key)) = cx
+            .platform()
+            .read_credentials(OPENAI_API_URL)
+            .log_err()
+            .flatten()
+        {
+            String::from_utf8(api_key).log_err()
+        } else {
+            None
+        }
     }
 
     fn max_tokens_per_batch(&self) -> usize {
@@ -290,11 +283,15 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         (output, tokens.len())
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
-        let Some(api_key) = self.api_key.clone() else {
+        let Some(api_key) = api_key else {
             return Err(anyhow!("no open ai key provided"));
         };
 

crates/project/src/project.rs

@@ -53,7 +53,7 @@ use lsp::{
 use lsp_command::*;
 use node_runtime::NodeRuntime;
 use postage::watch;
-use prettier::{LocateStart, Prettier, PRETTIER_SERVER_FILE, PRETTIER_SERVER_JS};
+use prettier::{LocateStart, Prettier};
 use project_settings::{LspSettings, ProjectSettings};
 use rand::prelude::*;
 use search::SearchQuery;
@@ -79,13 +79,10 @@ use std::{
     time::{Duration, Instant},
 };
 use terminals::Terminals;
-use text::{Anchor, LineEnding, Rope};
+use text::Anchor;
 use util::{
-    debug_panic, defer,
-    http::HttpClient,
-    merge_json_value_into,
-    paths::{DEFAULT_PRETTIER_DIR, LOCAL_SETTINGS_RELATIVE_PATH},
-    post_inc, ResultExt, TryFutureExt as _,
+    debug_panic, defer, http::HttpClient, merge_json_value_into,
+    paths::LOCAL_SETTINGS_RELATIVE_PATH, post_inc, ResultExt, TryFutureExt as _,
 };
 
 pub use fs::*;
@@ -8489,6 +8486,18 @@ impl Project {
         }
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    fn install_default_formatters(
+        &self,
+        _worktree: Option<WorktreeId>,
+        _new_language: &Language,
+        _language_settings: &LanguageSettings,
+        _cx: &mut ModelContext<Self>,
+    ) -> Task<anyhow::Result<()>> {
+        return Task::ready(Ok(()));
+    }
+
+    #[cfg(not(any(test, feature = "test-support")))]
     fn install_default_formatters(
         &self,
         worktree: Option<WorktreeId>,
@@ -8519,7 +8528,7 @@ impl Project {
             return Task::ready(Ok(()));
         };
 
-        let default_prettier_dir = DEFAULT_PRETTIER_DIR.as_path();
+        let default_prettier_dir = util::paths::DEFAULT_PRETTIER_DIR.as_path();
         let already_running_prettier = self
             .prettier_instances
             .get(&(worktree, default_prettier_dir.to_path_buf()))
@@ -8528,10 +8537,10 @@ impl Project {
         let fs = Arc::clone(&self.fs);
         cx.background()
             .spawn(async move {
-                let prettier_wrapper_path = default_prettier_dir.join(PRETTIER_SERVER_FILE);
+                let prettier_wrapper_path = default_prettier_dir.join(prettier::PRETTIER_SERVER_FILE);
                 // method creates parent directory if it doesn't exist
-                fs.save(&prettier_wrapper_path, &Rope::from(PRETTIER_SERVER_JS), LineEnding::Unix).await
-                .with_context(|| format!("writing {PRETTIER_SERVER_FILE} file at {prettier_wrapper_path:?}"))?;
+                fs.save(&prettier_wrapper_path, &text::Rope::from(prettier::PRETTIER_SERVER_JS), text::LineEnding::Unix).await
+                .with_context(|| format!("writing {} file at {prettier_wrapper_path:?}", prettier::PRETTIER_SERVER_FILE))?;
 
                 let packages_to_versions = future::try_join_all(
                     prettier_plugins

crates/semantic_index/src/embedding_queue.rs

@@ -41,6 +41,7 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
+    api_key: Option<String>,
 }
 
 #[derive(Clone)]
@@ -50,7 +51,11 @@ pub struct FileFragmentToEmbed {
 }
 
 impl EmbeddingQueue {
-    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
+    pub fn new(
+        embedding_provider: Arc<dyn EmbeddingProvider>,
+        executor: Arc<Background>,
+        api_key: Option<String>,
+    ) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
@@ -59,9 +64,14 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
+            api_key,
         }
     }
 
+    pub fn set_api_key(&mut self, api_key: Option<String>) {
+        self.api_key = api_key
+    }
+
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
@@ -108,6 +118,7 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();
 
         self.executor
             .spawn(async move {
@@ -132,7 +143,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans).await {
+                match embedding_provider.embed_batch(spans, api_key).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {

crates/semantic_index/src/semantic_index.rs

@@ -7,10 +7,7 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::{
-    completion::OPENAI_API_URL,
-    embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
-};
+use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -58,19 +55,6 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
-    let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-        Some(api_key)
-    } else if let Some((_, api_key)) = cx
-        .platform()
-        .read_credentials(OPENAI_API_URL)
-        .log_err()
-        .flatten()
-    {
-        String::from_utf8(api_key).log_err()
-    } else {
-        None
-    };
-
     cx.subscribe_global::<WorkspaceCreated, _>({
         move |event, cx| {
             let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -104,7 +88,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
             language_registry,
             cx.clone(),
         )
@@ -139,6 +123,8 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
+    api_key: Option<String>,
+    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
 struct ProjectState {
@@ -284,7 +270,7 @@ pub struct SearchResult {
 }
 
 impl SemanticIndex {
-    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+    pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
             Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
         } else {
@@ -292,12 +278,26 @@ impl SemanticIndex {
         }
     }
 
+    pub fn authenticate(&mut self, cx: &AppContext) {
+        if self.api_key.is_none() {
+            self.api_key = self.embedding_provider.retrieve_credentials(cx);
+
+            self.embedding_queue
+                .lock()
+                .set_api_key(self.api_key.clone());
+        }
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.api_key.is_some()
+    }
+
     pub fn enabled(cx: &AppContext) -> bool {
         settings::get::<SemanticIndexSettings>(cx).enabled
     }
 
     pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
-        if !self.embedding_provider.is_authenticated() {
+        if !self.is_authenticated() {
             return SemanticIndexStatus::NotAuthenticated;
         }
 
@@ -339,7 +339,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -404,6 +404,8 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
+                api_key: None,
+                embedding_queue
             }
         }))
     }
@@ -718,12 +720,13 @@ impl SemanticIndex {
 
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
+        let api_key = self.api_key.clone();
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
             let query = embedding_provider
-                .embed_batch(vec![query])
+                .embed_batch(vec![query], api_key)
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -941,6 +944,7 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
+        let api_key = self.api_key.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -955,10 +959,15 @@ impl SemanticIndex {
                     .parse_file_with_template(None, &snapshot.text(), language)
                     .log_err()
                     .unwrap_or_default();
-                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
-                    .await
-                    .log_err()
-                    .is_some()
+                if Self::embed_spans(
+                    &mut spans,
+                    embedding_provider.as_ref(),
+                    &db,
+                    api_key.clone(),
+                )
+                .await
+                .log_err()
+                .is_some()
                 {
                     for span in spans {
                         let similarity = span.embedding.unwrap().similarity(&query);
@@ -998,8 +1007,11 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if !self.embedding_provider.is_authenticated() {
-            return Task::ready(Err(anyhow!("user is not authenticated")));
+        if self.api_key.is_none() {
+            self.authenticate(cx);
+            if self.api_key.is_none() {
+                return Task::ready(Err(anyhow!("user is not authenticated")));
+            }
         }
 
         if !self.projects.contains_key(&project.downgrade()) {
@@ -1180,6 +1192,7 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
+        api_key: Option<String>,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1202,7 +1215,7 @@ impl SemanticIndex {
 
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch))
+                    .embed_batch(mem::take(&mut batch), api_key.clone())
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1214,7 +1227,7 @@ impl SemanticIndex {
 
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch))
+                .embed_batch(mem::take(&mut batch), api_key)
                 .await?;
 
             embeddings.extend(batch_embeddings);

crates/semantic_index/src/semantic_index_tests.rs

@@ -7,7 +7,7 @@ use crate::{
 use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
 use anyhow::Result;
 use async_trait::async_trait;
-use gpui::{executor::Deterministic, Task, TestAppContext};
+use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
@@ -228,7 +228,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
     for file in &files {
         queue.push(file.clone());
     }
@@ -1281,8 +1281,8 @@ impl FakeEmbeddingProvider {
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn is_authenticated(&self) -> bool {
-        true
+    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+        Some("Fake Credentials".to_string())
     }
     fn truncate(&self, span: &str) -> (String, usize) {
         (span.to_string(), 1)
@@ -1296,7 +1296,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         None
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+    async fn embed_batch(
+        &self,
+        spans: Vec<String>,
+        _api_key: Option<String>,
+    ) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
         Ok(spans.iter().map(|span| self.embed_sync(span)).collect())

crates/zed/examples/semantic_index_eval.rs

@@ -1,4 +1,3 @@
-use ai::completion::OPENAI_API_URL;
 use ai::embedding::OpenAIEmbeddings;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
@@ -18,7 +17,6 @@ use std::{cmp, env, fs};
 use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
 use util::http::{self};
 use util::paths::EMBEDDINGS_DIR;
-use util::ResultExt;
 use zed::languages;
 
 #[derive(Deserialize, Clone, Serialize)]
@@ -57,7 +55,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
         .as_path()
         .parent()
         .unwrap()
-        .join("crates/semantic_index/eval");
+        .join("zed/crates/semantic_index/eval");
 
     let mut repo_evals: Vec<RepoEval> = Vec::new();
     for entry in fs::read_dir(eval_folder)? {
@@ -472,25 +470,12 @@ fn main() {
 
         let languages = languages.clone();
 
-        let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        };
-
         let fs = fs.clone();
         cx.spawn(|mut cx| async move {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
             )

script/evaluate_semantic_index

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release
+RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release