parallelize embedding api calls

KCaverly created

Change summary

crates/semantic_index/src/embedding.rs      |  6 ++
crates/semantic_index/src/semantic_index.rs | 54 +++++++++++++++-------
2 files changed, 42 insertions(+), 18 deletions(-)

Detailed changes

crates/semantic_index/src/embedding.rs 🔗

@@ -106,7 +106,7 @@ impl OpenAIEmbeddings {
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
-        const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
+        const BACKOFF_SECONDS: [usize; 3] = [45, 75, 125];
         const MAX_RETRIES: usize = 3;
 
         let api_key = OPENAI_API_KEY
@@ -133,6 +133,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             match response.status() {
                 StatusCode::TOO_MANY_REQUESTS => {
                     let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                    log::trace!(
+                        "open ai rate limiting, delaying request by {:?} seconds",
+                        delay.as_secs()
+                    );
                     self.executor.timer(delay).await;
                 }
                 StatusCode::BAD_REQUEST => {

crates/semantic_index/src/semantic_index.rs 🔗

@@ -24,7 +24,7 @@ use std::{
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
-    time::SystemTime,
+    time::{Instant, SystemTime},
 };
 use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -34,7 +34,7 @@ use util::{
 };
 
 const SEMANTIC_INDEX_VERSION: usize = 4;
-const EMBEDDINGS_BATCH_SIZE: usize = 150;
+const EMBEDDINGS_BATCH_SIZE: usize = 80;
 
 pub fn init(
     fs: Arc<dyn Fs>,
@@ -84,7 +84,7 @@ pub struct SemanticIndex {
     db_update_tx: channel::Sender<DbOperation>,
     parsing_files_tx: channel::Sender<PendingFile>,
     _db_update_task: Task<()>,
-    _embed_batch_task: Task<()>,
+    _embed_batch_tasks: Vec<Task<()>>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -189,6 +189,7 @@ impl SemanticIndex {
         language_registry: Arc<LanguageRegistry>,
         mut cx: AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
+        let t0 = Instant::now();
         let database_url = Arc::new(database_url);
 
         let db = cx
@@ -196,7 +197,13 @@ impl SemanticIndex {
             .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
             .await?;
 
+        log::trace!(
+            "db initialization took {:?} milliseconds",
+            t0.elapsed().as_millis()
+        );
+
         Ok(cx.add_model(|cx| {
+            let t0 = Instant::now();
             // Perform database operations
             let (db_update_tx, db_update_rx) = channel::unbounded();
             let _db_update_task = cx.background().spawn({
@@ -210,20 +217,24 @@ impl SemanticIndex {
             // Group documents into batches and send them to the embedding provider.
             let (embed_batch_tx, embed_batch_rx) =
                 channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
-            let _embed_batch_task = cx.background().spawn({
-                let db_update_tx = db_update_tx.clone();
-                let embedding_provider = embedding_provider.clone();
-                async move {
-                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
-                        Self::compute_embeddings_for_batch(
-                            embeddings_queue,
-                            &embedding_provider,
-                            &db_update_tx,
-                        )
-                        .await;
+            let mut _embed_batch_tasks = Vec::new();
+            for _ in 0..cx.background().num_cpus() {
+                let embed_batch_rx = embed_batch_rx.clone();
+                _embed_batch_tasks.push(cx.background().spawn({
+                    let db_update_tx = db_update_tx.clone();
+                    let embedding_provider = embedding_provider.clone();
+                    async move {
+                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+                            Self::compute_embeddings_for_batch(
+                                embeddings_queue,
+                                &embedding_provider,
+                                &db_update_tx,
+                            )
+                            .await;
+                        }
                     }
-                }
-            });
+                }));
+            }
 
             // Group documents into batches and send them to the embedding provider.
             let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
@@ -264,6 +275,10 @@ impl SemanticIndex {
                 }));
             }
 
+            log::trace!(
+                "semantic index task initialization took {:?} milliseconds",
+                t0.elapsed().as_millis()
+            );
             Self {
                 fs,
                 database_url,
@@ -272,7 +287,7 @@ impl SemanticIndex {
                 db_update_tx,
                 parsing_files_tx,
                 _db_update_task,
-                _embed_batch_task,
+                _embed_batch_tasks,
                 _batch_files_task,
                 _parsing_files_tasks,
                 projects: HashMap::new(),
@@ -460,6 +475,7 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
+        let t0 = Instant::now();
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -577,6 +593,10 @@ impl SemanticIndex {
                         }
                     }
 
+                    log::trace!(
+                        "walking worktree took {:?} milliseconds",
+                        t0.elapsed().as_millis()
+                    );
                     anyhow::Ok((count, job_count_rx))
                 })
                 .await