Change summary
crates/ai/src/providers/open_ai/embedding.rs | 11 ++++++-----
crates/semantic_index/src/embedding_queue.rs | 2 +-
crates/semantic_index/src/semantic_index.rs | 17 ++++++-----------
3 files changed, 13 insertions(+), 17 deletions(-)
Detailed changes
@@ -162,14 +162,15 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
async fn embed_batch(
&self,
spans: Vec<String>,
- _credential: ProviderCredential,
+ credential: ProviderCredential,
) -> Result<Vec<Embedding>> {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
- let api_key = OPENAI_API_KEY
- .as_ref()
- .ok_or_else(|| anyhow!("no api key"))?;
+ let api_key = match credential {
+ ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key),
+ _ => Err(anyhow!("no api key provided")),
+ }?;
let mut request_number = 0;
let mut rate_limiting = false;
@@ -178,7 +179,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider {
while request_number < MAX_RETRIES {
response = self
.send_request(
- api_key,
+ &api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
@@ -41,7 +41,7 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
- provider_credential: ProviderCredential,
+ pub provider_credential: ProviderCredential,
}
#[derive(Clone)]
@@ -281,15 +281,13 @@ impl SemanticIndex {
}
pub fn authenticate(&mut self, cx: &AppContext) -> bool {
- let credential = self.provider_credential.clone();
- match credential {
- ProviderCredential::NoCredentials => {
- let credential = self.embedding_provider.retrieve_credentials(cx);
- self.provider_credential = credential;
- }
- _ => {}
- }
+ let existing_credential = self.provider_credential.clone();
+ let credential = match existing_credential {
+ ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
+ _ => existing_credential,
+ };
+ self.provider_credential = credential.clone();
self.embedding_queue.lock().set_credential(credential);
self.is_authenticated()
}
@@ -1020,14 +1018,11 @@ impl SemanticIndex {
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
if !self.is_authenticated() {
- println!("Authenticating");
if !self.authenticate(cx) {
return Task::ready(Err(anyhow!("user is not authenticated")));
}
}
- println!("SHOULD NOW BE AUTHENTICATED");
-
if !self.projects.contains_key(&project.downgrade()) {
let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {