Detailed changes
@@ -8,17 +8,8 @@ pub enum ProviderCredential {
}
pub trait CredentialProvider: Send + Sync {
+ fn has_credentials(&self) -> bool;
fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
fn delete_credentials(&self, cx: &AppContext);
}
-
-#[derive(Clone)]
-pub struct NullCredentialProvider;
-impl CredentialProvider for NullCredentialProvider {
- fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
- ProviderCredential::NotNeeded
- }
- fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {}
- fn delete_credentials(&self, cx: &AppContext) {}
-}
@@ -1,28 +1,14 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
-use gpui::AppContext;
-use crate::{
- auth::{CredentialProvider, ProviderCredential},
- models::LanguageModel,
-};
+use crate::{auth::CredentialProvider, models::LanguageModel};
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
-pub trait CompletionProvider {
+pub trait CompletionProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
- fn credential_provider(&self) -> Box<dyn CredentialProvider>;
- fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
- self.credential_provider().retrieve_credentials(cx)
- }
- fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
- self.credential_provider().save_credentials(cx, credential);
- }
- fn delete_credentials(&self, cx: &AppContext) {
- self.credential_provider().delete_credentials(cx);
- }
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
@@ -2,12 +2,11 @@ use std::time::Instant;
use anyhow::Result;
use async_trait::async_trait;
-use gpui::AppContext;
use ordered_float::OrderedFloat;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
-use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::auth::CredentialProvider;
use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
@@ -70,17 +69,9 @@ impl Embedding {
}
#[async_trait]
-pub trait EmbeddingProvider: Sync + Send {
+pub trait EmbeddingProvider: CredentialProvider {
fn base_model(&self) -> Box<dyn LanguageModel>;
- fn credential_provider(&self) -> Box<dyn CredentialProvider>;
- fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
- self.credential_provider().retrieve_credentials(cx)
- }
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- credential: ProviderCredential,
- ) -> Result<Vec<Embedding>>;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn rate_limit_expiration(&self) -> Option<Instant>;
}
@@ -1,46 +0,0 @@
-use std::env;
-
-use gpui::AppContext;
-use util::ResultExt;
-
-use crate::auth::{CredentialProvider, ProviderCredential};
-use crate::providers::open_ai::OPENAI_API_URL;
-
-#[derive(Clone)]
-pub struct OpenAICredentialProvider {}
-
-impl CredentialProvider for OpenAICredentialProvider {
- fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
-
- if let Some(api_key) = api_key {
- ProviderCredential::Credentials { api_key }
- } else {
- ProviderCredential::NoCredentials
- }
- }
- fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
- match credential {
- ProviderCredential::Credentials { api_key } => {
- cx.platform()
- .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
- .log_err();
- }
- _ => {}
- }
- }
- fn delete_credentials(&self, cx: &AppContext) {
- cx.platform().delete_credentials(OPENAI_API_URL).log_err();
- }
-}
@@ -3,14 +3,17 @@ use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
-use gpui::executor::Background;
+use gpui::{executor::Background, AppContext};
use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::{
+ env,
fmt::{self, Display},
io,
sync::Arc,
};
+use util::ResultExt;
use crate::{
auth::{CredentialProvider, ProviderCredential},
@@ -18,9 +21,7 @@ use crate::{
models::LanguageModel,
};
-use super::{auth::OpenAICredentialProvider, OpenAILanguageModel};
-
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
@@ -194,42 +195,83 @@ pub async fn stream_completion(
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
- credential_provider: OpenAICredentialProvider,
- credential: ProviderCredential,
+ credential: Arc<RwLock<ProviderCredential>>,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
- pub fn new(
- model_name: &str,
- credential: ProviderCredential,
- executor: Arc<Background>,
- ) -> Self {
+ pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
let model = OpenAILanguageModel::load(model_name);
- let credential_provider = OpenAICredentialProvider {};
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
model,
- credential_provider,
credential,
executor,
}
}
}
+impl CredentialProvider for OpenAICompletionProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ let mut credential = self.credential.write();
+ match *credential {
+ ProviderCredential::Credentials { .. } => {
+ return credential.clone();
+ }
+ _ => {
+ if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ *credential = ProviderCredential::Credentials { api_key };
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ *credential = ProviderCredential::Credentials { api_key };
+ }
+ } else {
+ };
+ }
+ }
+
+ credential.clone()
+ }
+
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+ match credential.clone() {
+ ProviderCredential::Credentials { api_key } => {
+ cx.platform()
+ .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ }
+
+ *self.credential.write() = credential;
+ }
+ fn delete_credentials(&self, cx: &AppContext) {
+ cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
impl CompletionProvider for OpenAICompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
model
}
- fn credential_provider(&self) -> Box<dyn CredentialProvider> {
- let provider: Box<dyn CredentialProvider> = Box::new(self.credential_provider.clone());
- provider
- }
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let credential = self.credential.clone();
+ let credential = self.credential.read().clone();
let request = stream_completion(credential, self.executor.clone(), prompt);
async move {
let response = request.await?;
@@ -2,27 +2,29 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
-use parking_lot::Mutex;
+use parking_lot::{Mutex, RwLock};
use parse_duration::parse;
use postage::watch;
use serde::{Deserialize, Serialize};
+use std::env;
use std::ops::Add;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
+use util::ResultExt;
use crate::auth::{CredentialProvider, ProviderCredential};
use crate::embedding::{Embedding, EmbeddingProvider};
use crate::models::LanguageModel;
use crate::providers::open_ai::OpenAILanguageModel;
-use crate::providers::open_ai::auth::OpenAICredentialProvider;
+use crate::providers::open_ai::OPENAI_API_URL;
lazy_static! {
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
@@ -31,7 +33,7 @@ lazy_static! {
#[derive(Clone)]
pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
- credential_provider: OpenAICredentialProvider,
+ credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -69,10 +71,11 @@ impl OpenAIEmbeddingProvider {
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
let model = OpenAILanguageModel::load("text-embedding-ada-002");
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
OpenAIEmbeddingProvider {
model,
- credential_provider: OpenAICredentialProvider {},
+ credential,
client,
executor,
rate_limit_count_rx,
@@ -80,6 +83,13 @@ impl OpenAIEmbeddingProvider {
}
}
+ fn get_api_key(&self) -> Result<String> {
+ match self.credential.read().clone() {
+ ProviderCredential::Credentials { api_key } => Ok(api_key),
+ _ => Err(anyhow!("api credentials not provided")),
+ }
+ }
+
fn resolve_rate_limit(&self) {
let reset_time = *self.rate_limit_count_tx.lock().borrow();
@@ -136,6 +146,57 @@ impl OpenAIEmbeddingProvider {
}
}
+impl CredentialProvider for OpenAIEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ let mut credential = self.credential.write();
+ match *credential {
+ ProviderCredential::Credentials { .. } => {
+ return credential.clone();
+ }
+ _ => {
+ if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ *credential = ProviderCredential::Credentials { api_key };
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ *credential = ProviderCredential::Credentials { api_key };
+ }
+ } else {
+ };
+ }
+ }
+
+ credential.clone()
+ }
+
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+ match credential.clone() {
+ ProviderCredential::Credentials { api_key } => {
+ cx.platform()
+ .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ }
+
+ *self.credential.write() = credential;
+ }
+ fn delete_credentials(&self, cx: &AppContext) {
+ cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
@@ -143,12 +204,6 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
model
}
- fn credential_provider(&self) -> Box<dyn CredentialProvider> {
- let credential_provider: Box<dyn CredentialProvider> =
- Box::new(self.credential_provider.clone());
- credential_provider
- }
-
fn max_tokens_per_batch(&self) -> usize {
50000
}
@@ -157,18 +212,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
*self.rate_limit_count_rx.borrow()
}
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- credential: ProviderCredential,
- ) -> Result<Vec<Embedding>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
- let api_key = match credential {
- ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
- _ => Err(anyhow!("no api key provided")),
- }?;
+ let api_key = self.get_api_key()?;
let mut request_number = 0;
let mut rate_limiting = false;
@@ -1,4 +1,3 @@
-pub mod auth;
pub mod completion;
pub mod embedding;
pub mod model;
@@ -6,3 +5,5 @@ pub mod model;
pub use completion::*;
pub use embedding::*;
pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
@@ -5,10 +5,11 @@ use std::{
use async_trait::async_trait;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::AppContext;
use parking_lot::Mutex;
use crate::{
- auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
+ auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
@@ -52,14 +53,12 @@ impl LanguageModel for FakeLanguageModel {
pub struct FakeEmbeddingProvider {
pub embedding_count: AtomicUsize,
- pub credential_provider: NullCredentialProvider,
}
impl Clone for FakeEmbeddingProvider {
fn clone(&self) -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
- credential_provider: self.credential_provider.clone(),
}
}
}
@@ -68,7 +67,6 @@ impl Default for FakeEmbeddingProvider {
fn default() -> Self {
FakeEmbeddingProvider {
embedding_count: AtomicUsize::default(),
- credential_provider: NullCredentialProvider {},
}
}
}
@@ -99,16 +97,22 @@ impl FakeEmbeddingProvider {
}
}
+impl CredentialProvider for FakeEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
Box::new(FakeLanguageModel { capacity: 1000 })
}
- fn credential_provider(&self) -> Box<dyn CredentialProvider> {
- let credential_provider: Box<dyn CredentialProvider> =
- Box::new(self.credential_provider.clone());
- credential_provider
- }
fn max_tokens_per_batch(&self) -> usize {
1000
}
@@ -117,11 +121,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
None
}
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- _credential: ProviderCredential,
- ) -> anyhow::Result<Vec<Embedding>> {
+ async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);
@@ -129,11 +129,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
}
}
-pub struct TestCompletionProvider {
+pub struct FakeCompletionProvider {
last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
}
-impl TestCompletionProvider {
+impl FakeCompletionProvider {
pub fn new() -> Self {
Self {
last_completion_tx: Mutex::new(None),
@@ -150,14 +150,22 @@ impl TestCompletionProvider {
}
}
-impl CompletionProvider for TestCompletionProvider {
+impl CredentialProvider for FakeCompletionProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
fn base_model(&self) -> Box<dyn LanguageModel> {
let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
model
}
- fn credential_provider(&self) -> Box<dyn CredentialProvider> {
- Box::new(NullCredentialProvider {})
- }
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,
@@ -10,7 +10,7 @@ use ai::{
auth::ProviderCredential,
completion::{CompletionProvider, CompletionRequest},
providers::open_ai::{
- stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage,
},
};
@@ -48,7 +48,7 @@ use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore;
use std::{
cell::{Cell, RefCell},
- cmp, env,
+ cmp,
fmt::Write,
iter,
ops::Range,
@@ -210,7 +210,6 @@ impl AssistantPanel {
// Defaulting currently to GPT4, allow for this to be set via config.
let completion_provider = Box::new(OpenAICompletionProvider::new(
"gpt-4",
- ProviderCredential::NoCredentials,
cx.background().clone(),
));
@@ -298,7 +297,6 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>,
project: &ModelHandle<Project>,
) {
- let credential = self.credential.borrow().clone();
let selection = editor.read(cx).selections.newest_anchor().clone();
if selection.start.excerpt_id() != selection.end.excerpt_id() {
return;
@@ -330,7 +328,6 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
"gpt-4",
- credential,
cx.background().clone(),
));
@@ -335,7 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
- use ai::test::TestCompletionProvider;
+ use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
@@ -379,7 +379,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -445,7 +445,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -511,7 +511,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -1,5 +1,5 @@
use crate::{parsing::Span, JobHandle};
-use ai::{auth::ProviderCredential, embedding::EmbeddingProvider};
+use ai::embedding::EmbeddingProvider;
use gpui::executor::Background;
use parking_lot::Mutex;
use smol::channel;
@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
- pub provider_credential: ProviderCredential,
}
#[derive(Clone)]
@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
}
impl EmbeddingQueue {
- pub fn new(
- embedding_provider: Arc<dyn EmbeddingProvider>,
- executor: Arc<Background>,
- provider_credential: ProviderCredential,
- ) -> Self {
+ pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
@@ -64,14 +59,9 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
- provider_credential,
}
}
- pub fn set_credential(&mut self, credential: ProviderCredential) {
- self.provider_credential = credential;
- }
-
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
@@ -118,7 +108,6 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
- let credential = self.provider_credential.clone();
self.executor
.spawn(async move {
@@ -143,7 +132,7 @@ impl EmbeddingQueue {
return;
};
- match embedding_provider.embed_batch(spans, credential).await {
+ match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
@@ -7,7 +7,6 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::auth::ProviderCredential;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
@@ -125,8 +124,6 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
- provider_credential: ProviderCredential,
- embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
struct ProjectState {
@@ -281,24 +278,17 @@ impl SemanticIndex {
}
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
- let existing_credential = self.provider_credential.clone();
- let credential = match existing_credential {
- ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
- _ => existing_credential,
- };
+ if !self.embedding_provider.has_credentials() {
+ self.embedding_provider.retrieve_credentials(cx);
+ } else {
+ return true;
+ }
- self.provider_credential = credential.clone();
- self.embedding_queue.lock().set_credential(credential);
- self.is_authenticated()
+ self.embedding_provider.has_credentials()
}
pub fn is_authenticated(&self) -> bool {
- let credential = &self.provider_credential;
- match credential {
- &ProviderCredential::Credentials { .. } => true,
- &ProviderCredential::NotNeeded => true,
- _ => false,
- }
+ self.embedding_provider.has_credentials()
}
pub fn enabled(cx: &AppContext) -> bool {
@@ -348,7 +338,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
- EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
+ EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@@ -413,8 +403,6 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
- provider_credential: ProviderCredential::NoCredentials,
- embedding_queue
}
}))
}
@@ -729,14 +717,13 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
- let credential = self.provider_credential.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
let query = embedding_provider
- .embed_batch(vec![query], credential)
+ .embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@@ -954,7 +941,6 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
- let credential = self.provider_credential.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@@ -969,15 +955,10 @@ impl SemanticIndex {
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
- if Self::embed_spans(
- &mut spans,
- embedding_provider.as_ref(),
- &db,
- credential.clone(),
- )
- .await
- .log_err()
- .is_some()
+ if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+ .await
+ .log_err()
+ .is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
@@ -1201,7 +1182,6 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
- credential: ProviderCredential,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@@ -1224,7 +1204,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), credential.clone())
+ .embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@@ -1236,7 +1216,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), credential)
+ .embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
@@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut queue = EmbeddingQueue::new(
- embedding_provider.clone(),
- cx.background(),
- ai::auth::ProviderCredential::NoCredentials,
- );
+ let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}