Change summary
crates/semantic_index/src/embedding.rs | 7 +++++++
crates/semantic_index/src/semantic_index.rs | 10 +++-------
crates/semantic_index/src/semantic_index_tests.rs | 3 +++
3 files changed, 13 insertions(+), 7 deletions(-)
Detailed changes
@@ -117,6 +117,7 @@ struct OpenAIEmbeddingUsage {
#[async_trait]
pub trait EmbeddingProvider: Sync + Send {
+ fn is_authenticated(&self) -> bool;
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
@@ -127,6 +128,9 @@ pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
+ fn is_authenticated(&self) -> bool {
+ true
+ }
fn rate_limit_expiration(&self) -> Option<Instant> {
None
}
@@ -229,6 +233,9 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
+ fn is_authenticated(&self) -> bool {
+ OPENAI_API_KEY.as_ref().is_some()
+ }
fn max_tokens_per_batch(&self) -> usize {
50000
}
@@ -281,12 +281,8 @@ impl SemanticIndex {
settings::get::<SemanticIndexSettings>(cx).enabled
}
- pub fn has_api_key(&self) -> bool {
- OPENAI_API_KEY.as_ref().is_some()
- }
-
pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
- if !self.has_api_key() {
+ if !self.embedding_provider.is_authenticated() {
return SemanticIndexStatus::NotAuthenticated;
}
@@ -980,8 +976,8 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
- if !self.has_api_key() {
- return Task::ready(Err(anyhow!("no open ai key present")));
+ if !self.embedding_provider.is_authenticated() {
+ return Task::ready(Err(anyhow!("user is not authenticated")));
}
if !self.projects.contains_key(&project.downgrade()) {
@@ -1267,6 +1267,9 @@ impl FakeEmbeddingProvider {
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
+ fn is_authenticated(&self) -> bool {
+ true
+ }
fn truncate(&self, span: &str) -> (String, usize) {
(span.to_string(), 1)
}