crates/ai/Cargo.toml 🔗
@@ -8,6 +8,9 @@ publish = false
path = "src/ai.rs"
doctest = false
+[features]
+test-support = []
+
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
KCaverly created
crates/ai/Cargo.toml | 3
crates/ai/src/ai.rs | 3
crates/ai/src/auth.rs | 20 ++
crates/ai/src/embedding.rs | 8
crates/ai/src/prompts/base.rs | 41 ----
crates/ai/src/providers/dummy.rs | 85 -----------
crates/ai/src/providers/mod.rs | 1
crates/ai/src/providers/open_ai/auth.rs | 33 ++++
crates/ai/src/providers/open_ai/embedding.rs | 46 +----
crates/ai/src/providers/open_ai/mod.rs | 1
crates/ai/src/test.rs | 123 +++++++++++++++++
crates/assistant/src/codegen.rs | 14 +
crates/semantic_index/Cargo.toml | 1
crates/semantic_index/src/embedding_queue.rs | 16 +-
crates/semantic_index/src/semantic_index.rs | 52 ++++--
crates/semantic_index/src/semantic_index_tests.rs | 101 ++-----------
16 files changed, 277 insertions(+), 271 deletions(-)
@@ -8,6 +8,9 @@ publish = false
path = "src/ai.rs"
doctest = false
+[features]
+test-support = []
+
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
@@ -1,5 +1,8 @@
+pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
pub mod prompts;
pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
@@ -0,0 +1,20 @@
+use gpui::AppContext;
+
+#[derive(Clone)]
+pub enum ProviderCredential {
+ Credentials { api_key: String },
+ NoCredentials,
+ NotNeeded,
+}
+
+pub trait CredentialProvider: Send + Sync {
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+}
+
+#[derive(Clone)]
+pub struct NullCredentialProvider;
+impl CredentialProvider for NullCredentialProvider {
+ fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+}
@@ -7,6 +7,7 @@ use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
+use crate::auth::{CredentialProvider, ProviderCredential};
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
@@ -71,11 +72,14 @@ impl Embedding {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
fn base_model(&self) -> Box<dyn LanguageModel>;
- fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
+ fn credential_provider(&self) -> Box<dyn CredentialProvider>;
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ self.credential_provider().retrieve_credentials(cx)
+ }
async fn embed_batch(
&self,
spans: Vec<String>,
- api_key: Option<String>,
+ credential: ProviderCredential,
) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
@@ -126,6 +126,7 @@ impl PromptChain {
#[cfg(test)]
pub(crate) mod tests {
use crate::models::TruncationDirection;
+ use crate::test::FakeLanguageModel;
use super::*;
@@ -181,39 +182,7 @@ pub(crate) mod tests {
}
}
- #[derive(Clone)]
- struct DummyLanguageModel {
- capacity: usize,
- }
-
- impl LanguageModel for DummyLanguageModel {
- fn name(&self) -> String {
- "dummy".to_string()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- anyhow::Ok(content.chars().collect::<Vec<char>>().len())
- }
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String> {
- anyhow::Ok(match direction {
- TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
- .into_iter()
- .collect::<String>(),
- TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
- .into_iter()
- .collect::<String>(),
- })
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(self.capacity)
- }
- }
-
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -249,7 +218,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -286,7 +255,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -322,7 +291,7 @@ pub(crate) mod tests {
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -1,85 +0,0 @@
-use std::time::Instant;
-
-use crate::{
- completion::CompletionRequest,
- embedding::{Embedding, EmbeddingProvider},
- models::{LanguageModel, TruncationDirection},
-};
-use async_trait::async_trait;
-use gpui::AppContext;
-use serde::Serialize;
-
-pub struct DummyLanguageModel {}
-
-impl LanguageModel for DummyLanguageModel {
- fn name(&self) -> String {
- "dummy".to_string()
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(1000)
- }
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: crate::models::TruncationDirection,
- ) -> anyhow::Result<String> {
- if content.len() < length {
- return anyhow::Ok(content.to_string());
- }
-
- let truncated = match direction {
- TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
- .iter()
- .collect::<String>(),
- TruncationDirection::Start => content.chars().collect::<Vec<char>>()[..length]
- .iter()
- .collect::<String>(),
- };
-
- anyhow::Ok(truncated)
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- anyhow::Ok(content.chars().collect::<Vec<char>>().len())
- }
-}
-
-#[derive(Serialize)]
-pub struct DummyCompletionRequest {
- pub name: String,
-}
-
-impl CompletionRequest for DummyCompletionRequest {
- fn data(&self) -> serde_json::Result<String> {
- serde_json::to_string(self)
- }
-}
-
-pub struct DummyEmbeddingProvider {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddingProvider {
- fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
- Some("Dummy Credentials".to_string())
- }
- fn base_model(&self) -> Box<dyn LanguageModel> {
- Box::new(DummyLanguageModel {})
- }
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- api_key: Option<String>,
- ) -> anyhow::Result<Vec<Embedding>> {
- // 1024 is the OpenAI Embeddings size for ada models.
- // the model we will likely be starting with.
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
- return Ok(vec![dummy_vec; spans.len()]);
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- 8190
- }
-}
@@ -1,2 +1 @@
-pub mod dummy;
pub mod open_ai;
@@ -0,0 +1,33 @@
+use std::env;
+
+use gpui::AppContext;
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::providers::open_ai::OPENAI_API_URL;
+
+#[derive(Clone)]
+pub struct OpenAICredentialProvider {}
+
+impl CredentialProvider for OpenAICredentialProvider {
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+
+ if let Some(api_key) = api_key {
+ ProviderCredential::Credentials { api_key }
+ } else {
+ ProviderCredential::NoCredentials
+ }
+ }
+}
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
-use gpui::{serde_json, AppContext};
+use gpui::serde_json;
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
@@ -17,13 +17,13 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
-use util::ResultExt;
+use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
-use super::OPENAI_API_URL;
+use crate::providers::open_ai::auth::OpenAICredentialProvider;
lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
@@ -33,6 +33,7 @@ lazy_static! {
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
+ credential_provider: OpenAICredentialProvider,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -73,6 +74,7 @@ impl OpenAIEmbeddingProvider {
OpenAIEmbeddingProvider {
model,
+ credential_provider: OpenAICredentialProvider {},
client,
executor,
rate_limit_count_rx,
@@ -138,25 +140,17 @@ impl OpenAIEmbeddingProvider {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
- fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
- api_key
- }
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
+
+ fn credential_provider(&self) -> Box<dyn CredentialProvider> {
+ let credential_provider: Box<dyn CredentialProvider> =
+ Box::new(self.credential_provider.clone());
+ credential_provider
+ }
+
fn max_tokens_per_batch(&self) -> usize {
50000
}
@@ -164,25 +158,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn rate_limit_expiration(&self) -> Option<Instant> {
*self.rate_limit_count_rx.borrow()
}
- // fn truncate(&self, span: &str) -> (String, usize) {
- // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- // let output = if tokens.len() > OPENAI_INPUT_LIMIT {
- // tokens.truncate(OPENAI_INPUT_LIMIT);
- // OPENAI_BPE_TOKENIZER
- // .decode(tokens.clone())
- // .ok()
- // .unwrap_or_else(|| span.to_string())
- // } else {
- // span.to_string()
- // };
-
- // (output, tokens.len())
- // }
async fn embed_batch(
&self,
spans: Vec<String>,
- api_key: Option<String>,
+ _credential: ProviderCredential,
) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
@@ -1,3 +1,4 @@
+pub mod auth;
pub mod completion;
pub mod embedding;
pub mod model;
@@ -0,0 +1,123 @@
+use std::{
+ sync::atomic::{self, AtomicUsize, Ordering},
+ time::Instant,
+};
+
+use async_trait::async_trait;
+
+use crate::{
+ auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
+ embedding::{Embedding, EmbeddingProvider},
+ models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+ pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+ fn name(&self) -> String {
+ "dummy".to_string()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ if content.chars().count() < length {
+ return anyhow::Ok(content.to_string());
+ }
+ anyhow::Ok(match direction {
+ TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+ .into_iter()
+ .collect::<String>(),
+ TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+ .into_iter()
+ .collect::<String>(),
+ })
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(self.capacity)
+ }
+}
+
+pub struct FakeEmbeddingProvider {
+ pub embedding_count: AtomicUsize,
+ pub credential_provider: NullCredentialProvider,
+}
+
+impl Clone for FakeEmbeddingProvider {
+ fn clone(&self) -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+ credential_provider: self.credential_provider.clone(),
+ }
+ }
+}
+
+impl Default for FakeEmbeddingProvider {
+ fn default() -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::default(),
+ credential_provider: NullCredentialProvider {},
+ }
+ }
+}
+
+impl FakeEmbeddingProvider {
+ pub fn embedding_count(&self) -> usize {
+ self.embedding_count.load(atomic::Ordering::SeqCst)
+ }
+
+ pub fn embed_sync(&self, span: &str) -> Embedding {
+ let mut result = vec![1.0; 26];
+ for letter in span.chars() {
+ let letter = letter.to_ascii_lowercase();
+ if letter as u32 >= 'a' as u32 {
+ let ix = (letter as u32) - ('a' as u32);
+ if ix < 26 {
+ result[ix as usize] += 1.0;
+ }
+ }
+ }
+
+ let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+ for x in &mut result {
+ *x /= norm;
+ }
+
+ result.into()
+ }
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ Box::new(FakeLanguageModel { capacity: 1000 })
+ }
+ fn credential_provider(&self) -> Box<dyn CredentialProvider> {
+ let credential_provider: Box<dyn CredentialProvider> =
+ Box::new(self.credential_provider.clone());
+ credential_provider
+ }
+ fn max_tokens_per_batch(&self) -> usize {
+ 1000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
+
+ async fn embed_batch(
+ &self,
+ spans: Vec<String>,
+ _credential: ProviderCredential,
+ ) -> anyhow::Result<Vec<Embedding>> {
+ self.embedding_count
+ .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+ anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+ }
+}
@@ -335,7 +335,6 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
- use ai::providers::dummy::DummyCompletionRequest;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
@@ -345,9 +344,21 @@ mod tests {
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
+ use serde::Serialize;
use settings::SettingsStore;
use smol::future::FutureExt;
+ #[derive(Serialize)]
+ pub struct DummyCompletionRequest {
+ pub name: String,
+ }
+
+ impl CompletionRequest for DummyCompletionRequest {
+ fn data(&self) -> serde_json::Result<String> {
+ serde_json::to_string(self)
+ }
+ }
+
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
cx: &mut TestAppContext,
@@ -381,6 +392,7 @@ mod tests {
cx,
)
});
+
let request = Box::new(DummyCompletionRequest {
name: "test".to_string(),
});
@@ -42,6 +42,7 @@ sha1 = "0.10.5"
ndarray = { version = "0.15.0" }
[dev-dependencies]
+ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
@@ -1,5 +1,5 @@
use crate::{parsing::Span, JobHandle};
-use ai::embedding::EmbeddingProvider;
+use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
@@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
- api_key: Option<String>,
+ provider_credential: ProviderCredential,
}
#[derive(Clone)]
@@ -54,7 +54,7 @@ impl EmbeddingQueue {
pub fn new(
embedding_provider: Arc<dyn EmbeddingProvider>,
executor: Arc<Background>,
- api_key: Option<String>,
+ provider_credential: ProviderCredential,
) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
@@ -64,12 +64,12 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
- api_key,
+ provider_credential,
}
}
- pub fn set_api_key(&mut self, api_key: Option<String>) {
- self.api_key = api_key
+ pub fn set_credential(&mut self, credential: ProviderCredential) {
+ self.provider_credential = credential
}
pub fn push(&mut self, file: FileToEmbed) {
@@ -118,7 +118,7 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
- let api_key = self.api_key.clone();
+ let credential = self.provider_credential.clone();
self.executor
.spawn(async move {
@@ -143,7 +143,7 @@ impl EmbeddingQueue {
return;
};
- match embedding_provider.embed_batch(spans, api_key).await {
+ match embedding_provider.embed_batch(spans, credential).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
@@ -7,6 +7,7 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
+use ai::auth::ProviderCredential;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
@@ -124,7 +125,7 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
- api_key: Option<String>,
+ provider_credential: ProviderCredential,
embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
@@ -279,18 +280,27 @@ impl SemanticIndex {
}
}
- pub fn authenticate(&mut self, cx: &AppContext) {
- if self.api_key.is_none() {
- self.api_key = self.embedding_provider.retrieve_credentials(cx);
-
- self.embedding_queue
- .lock()
- .set_api_key(self.api_key.clone());
+ pub fn authenticate(&mut self, cx: &AppContext) -> bool {
+ let credential = self.provider_credential.clone();
+ match credential {
+ ProviderCredential::NoCredentials => {
+ let credential = self.embedding_provider.retrieve_credentials(cx);
+ self.provider_credential = credential;
+ }
+ _ => {}
}
+
+ self.embedding_queue.lock().set_credential(self.provider_credential.clone());
+
+ self.is_authenticated()
}
pub fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ let credential = &self.provider_credential;
+ match credential {
+ &ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
}
pub fn enabled(cx: &AppContext) -> bool {
@@ -340,7 +350,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
- EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
+ EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@@ -405,7 +415,7 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
- api_key: None,
+ provider_credential: ProviderCredential::NoCredentials,
embedding_queue
}
}))
@@ -721,13 +731,14 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
- let api_key = self.api_key.clone();
+ let credential = self.provider_credential.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
+
let query = embedding_provider
- .embed_batch(vec![query], api_key)
+ .embed_batch(vec![query], credential)
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@@ -945,7 +956,7 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
- let api_key = self.api_key.clone();
+ let credential = self.provider_credential.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@@ -964,7 +975,7 @@ impl SemanticIndex {
&mut spans,
embedding_provider.as_ref(),
&db,
- api_key.clone(),
+ credential.clone(),
)
.await
.log_err()
@@ -1008,9 +1019,8 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
- if self.api_key.is_none() {
- self.authenticate(cx);
- if self.api_key.is_none() {
+ if !self.is_authenticated() {
+ if !self.authenticate(cx) {
return Task::ready(Err(anyhow!("user is not authenticated")));
}
}
@@ -1193,7 +1203,7 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
- api_key: Option<String>,
+ credential: ProviderCredential,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@@ -1216,7 +1226,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), api_key.clone())
+ .embed_batch(mem::take(&mut batch), credential.clone())
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@@ -1228,7 +1238,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), api_key)
+ .embed_batch(mem::take(&mut batch), credential)
.await?;
embeddings.extend(batch_embeddings);
@@ -4,14 +4,9 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
-use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel};
-use ai::{
- embedding::{Embedding, EmbeddingProvider},
- models::LanguageModel,
-};
-use anyhow::Result;
-use async_trait::async_trait;
-use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
+use ai::test::FakeEmbeddingProvider;
+
+use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
@@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
-use std::{
- path::Path,
- sync::{
- atomic::{self, AtomicUsize},
- Arc,
- },
- time::{Instant, SystemTime},
-};
+use std::{path::Path, sync::Arc, time::SystemTime};
use unindent::Unindent;
use util::RandomCharIter;
@@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
+ let mut queue = EmbeddingQueue::new(
+ embedding_provider.clone(),
+ cx.background(),
+ ai::auth::ProviderCredential::NoCredentials,
+ );
for file in &files {
queue.push(file.clone());
}
@@ -284,7 +276,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -470,7 +462,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -913,7 +905,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
- let embedding_provider = Arc::new(DummyEmbeddingProvider {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() {
);
}
-#[derive(Default)]
-struct FakeEmbeddingProvider {
- embedding_count: AtomicUsize,
-}
-
-impl FakeEmbeddingProvider {
- fn embedding_count(&self) -> usize {
- self.embedding_count.load(atomic::Ordering::SeqCst)
- }
-
- fn embed_sync(&self, span: &str) -> Embedding {
- let mut result = vec![1.0; 26];
- for letter in span.chars() {
- let letter = letter.to_ascii_lowercase();
- if letter as u32 >= 'a' as u32 {
- let ix = (letter as u32) - ('a' as u32);
- if ix < 26 {
- result[ix as usize] += 1.0;
- }
- }
- }
-
- let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
- for x in &mut result {
- *x /= norm;
- }
-
- result.into()
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- Box::new(DummyLanguageModel {})
- }
- fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
- Some("Fake Credentials".to_string())
- }
- fn max_tokens_per_batch(&self) -> usize {
- 1000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
-
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- _api_key: Option<String>,
- ) -> Result<Vec<Embedding>> {
- self.embedding_count
- .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-
- anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
- }
-}
-
fn js_lang() -> Arc<Language> {
Arc::new(
Language::new(