Detailed changes
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
-use gpui::{serde_json, ViewContext};
+use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
@@ -89,7 +89,6 @@ impl Embedding {
#[derive(Clone)]
pub struct OpenAIEmbeddings {
- pub api_key: Option<String>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -123,8 +122,12 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
- fn is_authenticated(&self) -> bool;
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+ fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
+ async fn embed_batch(
+ &self,
+ spans: Vec<String>,
+ api_key: Option<String>,
+ ) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
fn rate_limit_expiration(&self) -> Option<Instant>;
@@ -134,13 +137,17 @@ pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
- fn is_authenticated(&self) -> bool {
- true
+ fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+ Some("Dummy API KEY".to_string())
}
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ async fn embed_batch(
+ &self,
+ spans: Vec<String>,
+ _api_key: Option<String>,
+ ) -> Result<Vec<Embedding>> {
// 1536 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
@@ -169,36 +176,11 @@ impl EmbeddingProvider for DummyEmbeddings {
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
- pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
- if self.api_key.is_none() {
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
-
- if let Some(api_key) = api_key {
- self.api_key = Some(api_key);
- }
- }
- }
- pub fn new(
- api_key: Option<String>,
- client: Arc<dyn HttpClient>,
- executor: Arc<Background>,
- ) -> Self {
+ pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings {
- api_key,
client,
executor,
rate_limit_count_rx,
@@ -264,8 +246,19 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
- fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
+ if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ }
}
fn max_tokens_per_batch(&self) -> usize {
@@ -290,11 +283,15 @@ impl EmbeddingProvider for OpenAIEmbeddings {
(output, tokens.len())
}
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ async fn embed_batch(
+ &self,
+ spans: Vec<String>,
+ api_key: Option<String>,
+ ) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
- let Some(api_key) = self.api_key.clone() else {
+ let Some(api_key) = api_key else {
return Err(anyhow!("no open ai key provided"));
};
@@ -53,7 +53,7 @@ use lsp::{
use lsp_command::*;
use node_runtime::NodeRuntime;
use postage::watch;
-use prettier::{LocateStart, Prettier, PRETTIER_SERVER_FILE, PRETTIER_SERVER_JS};
+use prettier::{LocateStart, Prettier};
use project_settings::{LspSettings, ProjectSettings};
use rand::prelude::*;
use search::SearchQuery;
@@ -79,13 +79,10 @@ use std::{
time::{Duration, Instant},
};
use terminals::Terminals;
-use text::{Anchor, LineEnding, Rope};
+use text::Anchor;
use util::{
- debug_panic, defer,
- http::HttpClient,
- merge_json_value_into,
- paths::{DEFAULT_PRETTIER_DIR, LOCAL_SETTINGS_RELATIVE_PATH},
- post_inc, ResultExt, TryFutureExt as _,
+ debug_panic, defer, http::HttpClient, merge_json_value_into,
+ paths::LOCAL_SETTINGS_RELATIVE_PATH, post_inc, ResultExt, TryFutureExt as _,
};
pub use fs::*;
@@ -8489,6 +8486,18 @@ impl Project {
}
}
+ #[cfg(any(test, feature = "test-support"))]
+ fn install_default_formatters(
+ &self,
+ _worktree: Option<WorktreeId>,
+ _new_language: &Language,
+ _language_settings: &LanguageSettings,
+ _cx: &mut ModelContext<Self>,
+ ) -> Task<anyhow::Result<()>> {
+ return Task::ready(Ok(()));
+ }
+
+ #[cfg(not(any(test, feature = "test-support")))]
fn install_default_formatters(
&self,
worktree: Option<WorktreeId>,
@@ -8519,7 +8528,7 @@ impl Project {
return Task::ready(Ok(()));
};
- let default_prettier_dir = DEFAULT_PRETTIER_DIR.as_path();
+ let default_prettier_dir = util::paths::DEFAULT_PRETTIER_DIR.as_path();
let already_running_prettier = self
.prettier_instances
.get(&(worktree, default_prettier_dir.to_path_buf()))
@@ -8528,10 +8537,10 @@ impl Project {
let fs = Arc::clone(&self.fs);
cx.background()
.spawn(async move {
- let prettier_wrapper_path = default_prettier_dir.join(PRETTIER_SERVER_FILE);
+ let prettier_wrapper_path = default_prettier_dir.join(prettier::PRETTIER_SERVER_FILE);
// method creates parent directory if it doesn't exist
- fs.save(&prettier_wrapper_path, &Rope::from(PRETTIER_SERVER_JS), LineEnding::Unix).await
- .with_context(|| format!("writing {PRETTIER_SERVER_FILE} file at {prettier_wrapper_path:?}"))?;
+ fs.save(&prettier_wrapper_path, &text::Rope::from(prettier::PRETTIER_SERVER_JS), text::LineEnding::Unix).await
+ .with_context(|| format!("writing {} file at {prettier_wrapper_path:?}", prettier::PRETTIER_SERVER_FILE))?;
let packages_to_versions = future::try_join_all(
prettier_plugins
@@ -41,6 +41,7 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
+ api_key: Option<String>,
}
#[derive(Clone)]
@@ -50,7 +51,11 @@ pub struct FileFragmentToEmbed {
}
impl EmbeddingQueue {
- pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
+ pub fn new(
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ executor: Arc<Background>,
+ api_key: Option<String>,
+ ) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
@@ -59,9 +64,14 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
+ api_key,
}
}
+ pub fn set_api_key(&mut self, api_key: Option<String>) {
+ self.api_key = api_key
+ }
+
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
@@ -108,6 +118,7 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
+ let api_key = self.api_key.clone();
self.executor
.spawn(async move {
@@ -132,7 +143,7 @@ impl EmbeddingQueue {
return;
};
- match embedding_provider.embed_batch(spans).await {
+ match embedding_provider.embed_batch(spans, api_key).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
@@ -7,10 +7,7 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::{
- completion::OPENAI_API_URL,
- embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
-};
+use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@@ -58,19 +55,6 @@ pub fn init(
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
-
cx.subscribe_global::<WorkspaceCreated, _>({
move |event, cx| {
let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -104,7 +88,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
- Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+ Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -139,6 +123,8 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
+ api_key: Option<String>,
+ embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
struct ProjectState {
@@ -284,7 +270,7 @@ pub struct SearchResult {
}
impl SemanticIndex {
- pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+ pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
if cx.has_global::<ModelHandle<Self>>() {
Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
} else {
@@ -292,12 +278,26 @@ impl SemanticIndex {
}
}
+ pub fn authenticate(&mut self, cx: &AppContext) {
+ if self.api_key.is_none() {
+ self.api_key = self.embedding_provider.retrieve_credentials(cx);
+
+ self.embedding_queue
+ .lock()
+ .set_api_key(self.api_key.clone());
+ }
+ }
+
+ pub fn is_authenticated(&self) -> bool {
+ self.api_key.is_some()
+ }
+
pub fn enabled(cx: &AppContext) -> bool {
settings::get::<SemanticIndexSettings>(cx).enabled
}
pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
- if !self.embedding_provider.is_authenticated() {
+ if !self.is_authenticated() {
return SemanticIndexStatus::NotAuthenticated;
}
@@ -339,7 +339,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
- EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+ EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@@ -404,6 +404,8 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
+ api_key: None,
+ embedding_queue
}
}))
}
@@ -718,12 +720,13 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
+ let api_key = self.api_key.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let query = embedding_provider
- .embed_batch(vec![query])
+ .embed_batch(vec![query], api_key)
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@@ -941,6 +944,7 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
+ let api_key = self.api_key.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@@ -955,10 +959,15 @@ impl SemanticIndex {
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
- if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
- .await
- .log_err()
- .is_some()
+ if Self::embed_spans(
+ &mut spans,
+ embedding_provider.as_ref(),
+ &db,
+ api_key.clone(),
+ )
+ .await
+ .log_err()
+ .is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
@@ -998,8 +1007,11 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
- if !self.embedding_provider.is_authenticated() {
- return Task::ready(Err(anyhow!("user is not authenticated")));
+ if self.api_key.is_none() {
+ self.authenticate(cx);
+ if self.api_key.is_none() {
+ return Task::ready(Err(anyhow!("user is not authenticated")));
+ }
}
if !self.projects.contains_key(&project.downgrade()) {
@@ -1180,6 +1192,7 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
+ api_key: Option<String>,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@@ -1202,7 +1215,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch))
+ .embed_batch(mem::take(&mut batch), api_key.clone())
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@@ -1214,7 +1227,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch))
+ .embed_batch(mem::take(&mut batch), api_key)
.await?;
embeddings.extend(batch_embeddings);
@@ -7,7 +7,7 @@ use crate::{
use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
use anyhow::Result;
use async_trait::async_trait;
-use gpui::{executor::Deterministic, Task, TestAppContext};
+use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
@@ -228,7 +228,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+ let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
for file in &files {
queue.push(file.clone());
}
@@ -1281,8 +1281,8 @@ impl FakeEmbeddingProvider {
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
- fn is_authenticated(&self) -> bool {
- true
+ fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
+ Some("Fake Credentials".to_string())
}
fn truncate(&self, span: &str) -> (String, usize) {
(span.to_string(), 1)
@@ -1296,7 +1296,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
None
}
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ async fn embed_batch(
+ &self,
+ spans: Vec<String>,
+ _api_key: Option<String>,
+ ) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
@@ -1,4 +1,3 @@
-use ai::completion::OPENAI_API_URL;
use ai::embedding::OpenAIEmbeddings;
use anyhow::{anyhow, Result};
use client::{self, UserStore};
@@ -18,7 +17,6 @@ use std::{cmp, env, fs};
use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
use util::http::{self};
use util::paths::EMBEDDINGS_DIR;
-use util::ResultExt;
use zed::languages;
#[derive(Deserialize, Clone, Serialize)]
@@ -57,7 +55,7 @@ fn parse_eval() -> anyhow::Result<Vec<RepoEval>> {
.as_path()
.parent()
.unwrap()
- .join("crates/semantic_index/eval");
+ .join("zed/crates/semantic_index/eval");
let mut repo_evals: Vec<RepoEval> = Vec::new();
for entry in fs::read_dir(eval_folder)? {
@@ -472,25 +470,12 @@ fn main() {
let languages = languages.clone();
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
-
let fs = fs.clone();
cx.spawn(|mut cx| async move {
let semantic_index = SemanticIndex::new(
fs.clone(),
db_file_path,
- Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
+ Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
languages.clone(),
cx.clone(),
)
@@ -1,3 +1,3 @@
#!/bin/bash
-RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release
+RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release