move OpenAIEmbeddings to OpenAIEmbeddingProvider in providers folder

KCaverly created

Change summary

crates/ai/src/embedding.rs                        | 287 ----------------
crates/ai/src/providers/dummy.rs                  |  37 ++
crates/ai/src/providers/open_ai/embedding.rs      | 252 ++++++++++++++
crates/ai/src/providers/open_ai/mod.rs            |   3 
crates/semantic_index/src/semantic_index.rs       |   5 
crates/semantic_index/src/semantic_index_tests.rs |  19 
crates/zed/examples/semantic_index_eval.rs        |   4 
7 files changed, 308 insertions(+), 299 deletions(-)

Detailed changes

crates/ai/src/embedding.rs 🔗

@@ -1,30 +1,9 @@
-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::serde_json;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
 use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-
-lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use std::time::Instant;
 
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
@@ -85,39 +64,6 @@ impl Embedding {
     }
 }
 
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
-    pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<Background>,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     fn is_authenticated(&self) -> bool;
@@ -127,235 +73,6 @@ pub trait EmbeddingProvider: Sync + Send {
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
 
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        true
-    }
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
-        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
-        return Ok(vec![dummy_vec; spans.len()]);
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let token_count = tokens.len();
-        let output = if token_count > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            new_input.ok().unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        OpenAIEmbeddings {
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post("https://api.openai.com/v1/embeddings")
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        OPENAI_API_KEY.as_ref().is_some()
-    }
-    fn max_tokens_per_batch(&self) -> usize {
-        50000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        *self.rate_limit_count_rx.borrow()
-    }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            OPENAI_BPE_TOKENIZER
-                .decode(tokens.clone())
-                .ok()
-                .unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
-        const MAX_RETRIES: usize = 4;
-
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
-
-        let mut request_number = 0;
-        let mut rate_limiting = false;
-        let mut request_timeout: u64 = 15;
-        let mut response: Response<AsyncBody>;
-        while request_number < MAX_RETRIES {
-            response = self
-                .send_request(
-                    api_key,
-                    spans.iter().map(|x| &**x).collect(),
-                    request_timeout,
-                )
-                .await?;
-
-            request_number += 1;
-
-            match response.status() {
-                StatusCode::REQUEST_TIMEOUT => {
-                    request_timeout += 5;
-                }
-                StatusCode::OK => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
-                    log::trace!(
-                        "openai embedding completed. tokens: {:?}",
-                        response.usage.total_tokens
-                    );
-
-                    // If we complete a request successfully that was previously rate_limited
-                    // resolve the rate limit
-                    if rate_limiting {
-                        self.resolve_rate_limit()
-                    }
-
-                    return Ok(response
-                        .data
-                        .into_iter()
-                        .map(|embedding| Embedding::from(embedding.embedding))
-                        .collect());
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    rate_limiting = true;
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-
-                    let delay_duration = {
-                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                        if let Some(time_to_reset) =
-                            response.headers().get("x-ratelimit-reset-tokens")
-                        {
-                            if let Ok(time_str) = time_to_reset.to_str() {
-                                parse(time_str).unwrap_or(delay)
-                            } else {
-                                delay
-                            }
-                        } else {
-                            delay
-                        }
-                    };
-
-                    // If we've previously rate limited, increment the duration but not the count
-                    let reset_time = Instant::now().add(delay_duration);
-                    self.update_reset_time(reset_time);
-
-                    log::trace!(
-                        "openai rate limiting: waiting {:?} until lifted",
-                        &delay_duration
-                    );
-
-                    self.executor.timer(delay_duration).await;
-                }
-                _ => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    return Err(anyhow!(
-                        "open ai bad request: {:?} {:?}",
-                        &response.status(),
-                        body
-                    ));
-                }
-            }
-        }
-        Err(anyhow!("openai max retries"))
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

crates/ai/src/providers/dummy.rs 🔗

@@ -1,4 +1,10 @@
-use crate::completion::CompletionRequest;
+use std::time::Instant;
+
+use crate::{
+    completion::CompletionRequest,
+    embedding::{Embedding, EmbeddingProvider},
+};
+use async_trait::async_trait;
 use serde::Serialize;
 
 #[derive(Serialize)]
@@ -11,3 +17,32 @@ impl CompletionRequest for DummyCompletionRequest {
         serde_json::to_string(self)
     }
 }
+
+pub struct DummyEmbeddingProvider {}
+
+#[async_trait]
+impl EmbeddingProvider for DummyEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+        // 1024 is the OpenAI Embeddings size for ada models.
+        // the model we will likely be starting with.
+        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
+        return Ok(vec![dummy_vec; spans.len()]);
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        8190
+    }
+
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let truncated = span.chars().collect::<Vec<char>>()[..8190]
+            .iter()
+            .collect::<String>();
+        (truncated, 8190)
+    }
+}

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -0,0 +1,252 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::serde_json;
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::Mutex;
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+
+use crate::embedding::{Embedding, EmbeddingProvider};
+
+lazy_static! {
+    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+    pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+const OPENAI_INPUT_LIMIT: usize = 8190;
+
+impl OpenAIEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        OpenAIEmbeddingProvider {
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
+        let request = Request::post("https://api.openai.com/v1/embeddings")
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(request_timeout))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(
+                serde_json::to_string(&OpenAIEmbeddingRequest {
+                    input: spans.clone(),
+                    model: "text-embedding-ada-002",
+                })
+                .unwrap()
+                .into(),
+            )?;
+
+        Ok(self.client.send(request).await?)
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        OPENAI_API_KEY.as_ref().is_some()
+    }
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens.clone())
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        (output, tokens.len())
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+        const MAX_RETRIES: usize = 4;
+
+        let api_key = OPENAI_API_KEY
+            .as_ref()
+            .ok_or_else(|| anyhow!("no api key"))?;
+
+        let mut request_number = 0;
+        let mut rate_limiting = false;
+        let mut request_timeout: u64 = 15;
+        let mut response: Response<AsyncBody>;
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(
+                    api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
+                .await?;
+
+            request_number += 1;
+
+            match response.status() {
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::trace!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| Embedding::from(embedding.embedding))
+                        .collect());
+                }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    // If we've previously rate limited, increment the duration but not the count
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
+                _ => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
+                }
+            }
+        }
+        Err(anyhow!("openai max retries"))
+    }
+}

crates/semantic_index/src/semantic_index.rs 🔗

@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
             language_registry,
             cx.clone(),
         )

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -4,7 +4,8 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::dummy::DummyEmbeddingProvider;
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -280,7 +281,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -382,7 +383,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -466,7 +467,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -565,7 +566,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -639,7 +640,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -756,7 +757,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -909,7 +910,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1100,7 +1101,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"

crates/zed/examples/semantic_index_eval.rs 🔗

@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -474,7 +474,7 @@ fn main() {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+                Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
             )