diff --git a/crates/semantic_index/src/embedding.rs b/crates/semantic_index/src/embedding.rs index b0124bf7df2664f1b3f237edd601a8e59b196fbd..2b6e94854e2c0e8e21f9df11874f45d35dca8a59 100644 --- a/crates/semantic_index/src/embedding.rs +++ b/crates/semantic_index/src/embedding.rs @@ -117,6 +117,7 @@ struct OpenAIEmbeddingUsage { #[async_trait] pub trait EmbeddingProvider: Sync + Send { + fn is_authenticated(&self) -> bool; async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn truncate(&self, span: &str) -> (String, usize); @@ -127,6 +128,9 @@ pub struct DummyEmbeddings {} #[async_trait] impl EmbeddingProvider for DummyEmbeddings { + fn is_authenticated(&self) -> bool { + true + } fn rate_limit_expiration(&self) -> Option { None } @@ -229,6 +233,9 @@ impl OpenAIEmbeddings { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { + fn is_authenticated(&self) -> bool { + OPENAI_API_KEY.as_ref().is_some() + } fn max_tokens_per_batch(&self) -> usize { 50000 } diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index b6ad75a34ea9db1c2fea0a2a55bcfd86c6902a29..1ba0001cfda9c6be7128e1378bc75750cd1842da 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -281,12 +281,8 @@ impl SemanticIndex { settings::get::(cx).enabled } - pub fn has_api_key(&self) -> bool { - OPENAI_API_KEY.as_ref().is_some() - } - pub fn status(&self, project: &ModelHandle) -> SemanticIndexStatus { - if !self.has_api_key() { + if !self.embedding_provider.is_authenticated() { return SemanticIndexStatus::NotAuthenticated; } @@ -980,8 +976,8 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task> { - if !self.has_api_key() { - return Task::ready(Err(anyhow!("no open ai key present"))); + if !self.embedding_provider.is_authenticated() { + return Task::ready(Err(anyhow!("user is not authenticated"))); } if !self.projects.contains_key(&project.downgrade()) { diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 9035327b2e8d910312ebb2f7c117a51d300f5d6a..f386665915b1d9b9314c7246da93daec6624900e 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1267,6 +1267,9 @@ impl FakeEmbeddingProvider { #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { + fn is_authenticated(&self) -> bool { + true + } fn truncate(&self, span: &str) -> (String, usize) { (span.to_string(), 1) }