Corrected batching order and added handling for OpenAI embedding errors

Created by KCaverly

Change summary

crates/vector_store/Cargo.toml          |   1 
crates/vector_store/src/embedding.rs    | 132 ++++++++++-----
crates/vector_store/src/vector_store.rs | 221 ++++++++++----------------
3 files changed, 175 insertions(+), 179 deletions(-)

Detailed changes

crates/vector_store/Cargo.toml

@@ -32,6 +32,7 @@ async-trait.workspace = true
 bincode = "1.3.3"
 matrixmultiply = "0.3.7"
 tiktoken-rs = "0.5.0"
+rand.workspace = true
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }

crates/vector_store/src/embedding.rs

@@ -2,15 +2,20 @@ use anyhow::{anyhow, Result};
 use async_trait::async_trait;
 use futures::AsyncReadExt;
 use gpui::serde_json;
+use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
+use std::env;
 use std::sync::Arc;
-use std::{env, time::Instant};
+use std::time::Duration;
+use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
 
 lazy_static! {
     static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
 #[derive(Clone)]
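
For context on the truncation added in the next hunk: it reuses the lazily built cl100k_base tokenizer above. A minimal standalone sketch of the pattern, assuming tiktoken-rs; truncate_to_tokens and its signature are illustrative rather than the committed API:

use lazy_static::lazy_static;
use tiktoken_rs::{cl100k_base, CoreBPE};

lazy_static! {
    // Built once and reused; constructing the BPE tables is relatively expensive.
    static ref BPE: CoreBPE = cl100k_base().unwrap();
}

// Truncate a span to at most `max_tokens` BPE tokens, falling back to the
// original text if the truncated token stream cannot be decoded.
fn truncate_to_tokens(span: &str, max_tokens: usize) -> String {
    let mut tokens = BPE.encode_with_special_tokens(span);
    if tokens.len() > max_tokens {
        tokens.truncate(max_tokens);
        if let Ok(truncated) = BPE.decode(tokens) {
            return truncated;
        }
    }
    span.to_string()
}
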
@@ -60,69 +65,100 @@ impl EmbeddingProvider for DummyEmbeddings {
     }
 }
 
-// impl OpenAIEmbeddings {
-//     async fn truncate(span: &str) -> String {
-//         let bpe = cl100k_base().unwrap();
-//         let mut tokens = bpe.encode_with_special_tokens(span);
-//         if tokens.len() > 8192 {
-//             tokens.truncate(8192);
-//             let result = bpe.decode(tokens);
-//             if result.is_ok() {
-//                 return result.unwrap();
-//             }
-//         }
-
-//         return span.to_string();
-//     }
-// }
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
-        // Truncate spans to 8192 if needed
-        // let t0 = Instant::now();
-        // let mut truncated_spans = vec![];
-        // for span in spans {
-        //     truncated_spans.push(Self::truncate(span));
-        // }
-        // let spans = futures::future::join_all(truncated_spans).await;
-        // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs());
+impl OpenAIEmbeddings {
+    async fn truncate(span: String) -> String {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
+        if tokens.len() > 8190 {
+            tokens.truncate(8190);
+            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
+            if result.is_ok() {
+                let transformed = result.unwrap();
+                // assert_ne!(transformed, span);
+                return transformed;
+            }
+        }
 
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
+        return span.to_string();
+    }
 
+    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
         let request = Request::post("https://api.openai.com/v1/embeddings")
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {}", api_key))
             .body(
                 serde_json::to_string(&OpenAIEmbeddingRequest {
-                    input: spans,
+                    input: spans.clone(),
                     model: "text-embedding-ada-002",
                 })
                 .unwrap()
                 .into(),
             )?;
 
-        let mut response = self.client.send(request).await?;
-        if !response.status().is_success() {
-            return Err(anyhow!("openai embedding failed {}", response.status()));
-        }
+        Ok(self.client.send(request).await?)
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddings {
+    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+        const BACKOFF_SECONDS: [usize; 3] = [65, 180, 360];
+        const MAX_RETRIES: usize = 3;
 
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+        let api_key = OPENAI_API_KEY
+            .as_ref()
+            .ok_or_else(|| anyhow!("no api key"))?;
 
-        log::info!(
-            "openai embedding completed. tokens: {:?}",
-            response.usage.total_tokens
-        );
+        let mut request_number = 0;
+        let mut response: Response<AsyncBody>;
+        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(api_key, spans.iter().map(|x| &**x).collect())
+                .await?;
+            request_number += 1;
+
+            if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
+                return Err(anyhow!(
+                    "openai max retries, error: {:?}",
+                    &response.status()
+                ));
+            }
+
+            match response.status() {
+                StatusCode::TOO_MANY_REQUESTS => {
+                    let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                    std::thread::sleep(delay);
+                }
+                StatusCode::BAD_REQUEST => {
+                    log::info!("BAD REQUEST: {:?}", &response.status());
+                    // Don't worry about delaying bad request, as we can assume
+                    // we haven't been rate limited yet.
+                    for span in spans.iter_mut() {
+                        *span = Self::truncate(span.to_string()).await;
+                    }
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::info!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| embedding.embedding)
+                        .collect());
+                }
+                _ => {
+                    return Err(anyhow!("openai embedding failed {}", response.status()));
+                }
+            }
+        }
 
-        Ok(response
-            .data
-            .into_iter()
-            .map(|embedding| embedding.embedding)
-            .collect())
+        Err(anyhow!("openai embedding failed"))
     }
 }
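
The retry policy above keys off the HTTP status: 429 waits out a fixed backoff step, 400 token-truncates the spans and retries immediately, and any other failure bails out. A condensed sketch of that control flow with the request and truncation stubbed out as closures; send_once, shorten, and Outcome are illustrative stand-ins, not the committed types:

use std::time::Duration;

const BACKOFF_SECONDS: [u64; 3] = [65, 180, 360];
const MAX_RETRIES: usize = 3;

// One attempt's outcome, mirroring the HTTP statuses the loop cares about.
enum Outcome {
    Ok(Vec<Vec<f32>>), // 200
    RateLimited,       // 429
    BadRequest,        // 400
    Other(String),
}

// `send_once` stands in for a single embeddings request; `shorten` stands in
// for the token truncation applied after a 400.
fn embed_with_retries(
    mut spans: Vec<String>,
    send_once: impl Fn(&[String]) -> Outcome,
    shorten: impl Fn(&str) -> String,
) -> Result<Vec<Vec<f32>>, String> {
    for attempt in 0..MAX_RETRIES {
        match send_once(&spans) {
            Outcome::Ok(embeddings) => return Ok(embeddings),
            // 429: wait out the scheduled backoff before the next attempt.
            Outcome::RateLimited => {
                std::thread::sleep(Duration::from_secs(BACKOFF_SECONDS[attempt]));
            }
            // 400: assume the payload was too long; shorten spans and retry at once.
            Outcome::BadRequest => {
                for span in spans.iter_mut() {
                    *span = shorten(span);
                }
            }
            Outcome::Other(status) => return Err(status),
        }
    }
    Err("max retries exceeded".into())
}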

crates/vector_store/src/vector_store.rs

@@ -74,7 +74,6 @@ pub fn init(
             cx.subscribe_global::<WorkspaceCreated, _>({
                 let vector_store = vector_store.clone();
                 move |event, cx| {
-                    let t0 = Instant::now();
                     let workspace = &event.0;
                     if let Some(workspace) = workspace.upgrade(cx) {
                         let project = workspace.read(cx).project().clone();
@@ -126,9 +125,7 @@ pub struct VectorStore {
     language_registry: Arc<LanguageRegistry>,
     db_update_tx: channel::Sender<DbWrite>,
     // embed_batch_tx: channel::Sender<Vec<(i64, IndexedFile, Vec<String>)>>,
-    batch_files_tx: channel::Sender<(i64, IndexedFile, Vec<String>)>,
     parsing_files_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
-    parsing_files_rx: channel::Receiver<(i64, PathBuf, Arc<Language>, SystemTime)>,
     _db_update_task: Task<()>,
     _embed_batch_task: Vec<Task<()>>,
     _batch_files_task: Task<()>,
@@ -220,14 +217,13 @@ impl VectorStore {
             let (embed_batch_tx, embed_batch_rx) =
                 channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
             let mut _embed_batch_task = Vec::new();
-            for _ in 0..cx.background().num_cpus() {
+            for _ in 0..1 {
+                //cx.background().num_cpus() {
                 let db_update_tx = db_update_tx.clone();
                 let embed_batch_rx = embed_batch_rx.clone();
                 let embedding_provider = embedding_provider.clone();
                 _embed_batch_task.push(cx.background().spawn(async move {
                     while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
-                        log::info!("Embedding Batch! ");
-
                         // Construct Batch
                         let mut embeddings_queue = embeddings_queue.clone();
                         let mut document_spans = vec![];
@@ -235,20 +231,20 @@ impl VectorStore {
                             document_spans.extend(document_span);
                         }
 
-                        if let Some(mut embeddings) = embedding_provider
+                        if let Ok(embeddings) = embedding_provider
                             .embed_batch(document_spans.iter().map(|x| &**x).collect())
                             .await
-                            .log_err()
                         {
                             let mut i = 0;
                             let mut j = 0;
-                            while let Some(embedding) = embeddings.pop() {
+
+                            for embedding in embeddings.iter() {
                                 while embeddings_queue[i].1.documents.len() == j {
                                     i += 1;
                                     j = 0;
                                 }
 
-                                embeddings_queue[i].1.documents[j].embedding = embedding;
+                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
                                 j += 1;
                             }
 
@@ -283,7 +279,6 @@ impl VectorStore {
                 while let Ok((worktree_id, indexed_file, document_spans)) =
                     batch_files_rx.recv().await
                 {
-                    log::info!("Batching File: {:?}", &indexed_file.path);
                     queue_len += &document_spans.len();
                     embeddings_queue.push((worktree_id, indexed_file, document_spans));
                     if queue_len >= EMBEDDINGS_BATCH_SIZE {
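
As the hunk above shows, files arrive on batch_files_rx one at a time and are buffered until the number of pending spans reaches EMBEDDINGS_BATCH_SIZE, at which point the whole queue is handed to the embedding task. A rough synchronous sketch of that accumulation; the threshold value and the flush callback are illustrative:

const EMBEDDINGS_BATCH_SIZE: usize = 4096; // illustrative threshold

fn batch_spans<T>(
    incoming: impl Iterator<Item = (T, Vec<String>)>,
    mut flush: impl FnMut(Vec<(T, Vec<String>)>),
) {
    let mut queue = Vec::new();
    let mut queue_len = 0;
    for (file, spans) in incoming {
        queue_len += spans.len();
        queue.push((file, spans));
        // Once enough spans have accumulated, send the whole batch downstream.
        if queue_len >= EMBEDDINGS_BATCH_SIZE {
            flush(std::mem::take(&mut queue));
            queue_len = 0;
        }
    }
    // Flush any remainder when the stream ends.
    if !queue.is_empty() {
        flush(queue);
    }
}
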
@@ -338,10 +333,7 @@ impl VectorStore {
                 embedding_provider,
                 language_registry,
                 db_update_tx,
-                // embed_batch_tx,
-                batch_files_tx,
                 parsing_files_tx,
-                parsing_files_rx,
                 _db_update_task,
                 _embed_batch_task,
                 _batch_files_task,
@@ -449,8 +441,6 @@ impl VectorStore {
         let database_url = self.database_url.clone();
         let db_update_tx = self.db_update_tx.clone();
         let parsing_files_tx = self.parsing_files_tx.clone();
-        let parsing_files_rx = self.parsing_files_rx.clone();
-        let batch_files_tx = self.batch_files_tx.clone();
 
         cx.spawn(|this, mut cx| async move {
             let t0 = Instant::now();
@@ -553,37 +543,6 @@ impl VectorStore {
                 })
                 .detach();
 
-            // cx.background()
-            //     .scoped(|scope| {
-            //         for _ in 0..cx.background().num_cpus() {
-            //             scope.spawn(async {
-            //                 let mut parser = Parser::new();
-            //                 let mut cursor = QueryCursor::new();
-            //                 while let Ok((worktree_id, file_path, language, mtime)) =
-            //                     parsing_files_rx.recv().await
-            //                 {
-            //                     log::info!("Parsing File: {:?}", &file_path);
-            //                     if let Some((indexed_file, document_spans)) = Self::index_file(
-            //                         &mut cursor,
-            //                         &mut parser,
-            //                         &fs,
-            //                         language,
-            //                         file_path.clone(),
-            //                         mtime,
-            //                     )
-            //                     .await
-            //                     .log_err()
-            //                     {
-            //                         batch_files_tx
-            //                             .try_send((worktree_id, indexed_file, document_spans))
-            //                             .unwrap();
-            //                     }
-            //                 }
-            //             });
-            //         }
-            //     })
-            //     .await;
-
             this.update(&mut cx, |this, cx| {
                 // The below is managing for updated on save
                 // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
@@ -592,90 +551,90 @@ impl VectorStore {
                     if let Some(project_state) = this.projects.get(&project.downgrade()) {
                         let worktree_db_ids = project_state.worktree_db_ids.clone();
 
-                        // if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
-                        // {
-                        //     // Iterate through changes
-                        //     let language_registry = this.language_registry.clone();
-
-                        //     let db =
-                        //         VectorDatabase::new(this.database_url.to_string_lossy().into());
-                        //     if db.is_err() {
-                        //         return;
-                        //     }
-                        //     let db = db.unwrap();
-
-                        //     let worktree_db_id: Option<i64> = {
-                        //         let mut found_db_id = None;
-                        //         for (w_id, db_id) in worktree_db_ids.into_iter() {
-                        //             if &w_id == worktree_id {
-                        //                 found_db_id = Some(db_id);
-                        //             }
-                        //         }
-
-                        //         found_db_id
-                        //     };
-
-                        //     if worktree_db_id.is_none() {
-                        //         return;
-                        //     }
-                        //     let worktree_db_id = worktree_db_id.unwrap();
-
-                        //     let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                        //     if file_mtimes.is_err() {
-                        //         return;
-                        //     }
-
-                        //     let file_mtimes = file_mtimes.unwrap();
-                        //     let paths_tx = this.paths_tx.clone();
-
-                        //     smol::block_on(async move {
-                        //         for change in changes.into_iter() {
-                        //             let change_path = change.0.clone();
-                        //             log::info!("Change: {:?}", &change_path);
-                        //             if let Ok(language) = language_registry
-                        //                 .language_for_file(&change_path.to_path_buf(), None)
-                        //                 .await
-                        //             {
-                        //                 if language
-                        //                     .grammar()
-                        //                     .and_then(|grammar| grammar.embedding_config.as_ref())
-                        //                     .is_none()
-                        //                 {
-                        //                     continue;
-                        //                 }
-
-                        //                 // TODO: Make this a bit more defensive
-                        //                 let modified_time =
-                        //                     change_path.metadata().unwrap().modified().unwrap();
-                        //                 let existing_time =
-                        //                     file_mtimes.get(&change_path.to_path_buf());
-                        //                 let already_stored =
-                        //                     existing_time.map_or(false, |existing_time| {
-                        //                         if &modified_time != existing_time
-                        //                             && existing_time.elapsed().unwrap().as_secs()
-                        //                                 > REINDEXING_DELAY
-                        //                         {
-                        //                             false
-                        //                         } else {
-                        //                             true
-                        //                         }
-                        //                     });
-
-                        //                 if !already_stored {
-                        //                     log::info!("Need to reindex: {:?}", &change_path);
-                        //                     paths_tx
-                        //                         .try_send((
-                        //                             worktree_db_id,
-                        //                             change_path.to_path_buf(),
-                        //                             language,
-                        //                             modified_time,
-                        //                         ))
-                        //                         .unwrap();
-                        //                 }
-                        //             }
-                        //         }
-                        //     })
-                        // }
+                        if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
+                        {
+                            // Iterate through changes
+                            let language_registry = this.language_registry.clone();
+
+                            let db =
+                                VectorDatabase::new(this.database_url.to_string_lossy().into());
+                            if db.is_err() {
+                                return;
+                            }
+                            let db = db.unwrap();
+
+                            let worktree_db_id: Option<i64> = {
+                                let mut found_db_id = None;
+                                for (w_id, db_id) in worktree_db_ids.into_iter() {
+                                    if &w_id == worktree_id {
+                                        found_db_id = Some(db_id);
+                                    }
+                                }
+
+                                found_db_id
+                            };
+
+                            if worktree_db_id.is_none() {
+                                return;
+                            }
+                            let worktree_db_id = worktree_db_id.unwrap();
+
+                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
+                            if file_mtimes.is_err() {
+                                return;
+                            }
+
+                            let file_mtimes = file_mtimes.unwrap();
+                            let parsing_files_tx = this.parsing_files_tx.clone();
+
+                            smol::block_on(async move {
+                                for change in changes.into_iter() {
+                                    let change_path = change.0.clone();
+                                    log::info!("Change: {:?}", &change_path);
+                                    if let Ok(language) = language_registry
+                                        .language_for_file(&change_path.to_path_buf(), None)
+                                        .await
+                                    {
+                                        if language
+                                            .grammar()
+                                            .and_then(|grammar| grammar.embedding_config.as_ref())
+                                            .is_none()
+                                        {
+                                            continue;
+                                        }
+
+                                        // TODO: Make this a bit more defensive
+                                        let modified_time =
+                                            change_path.metadata().unwrap().modified().unwrap();
+                                        let existing_time =
+                                            file_mtimes.get(&change_path.to_path_buf());
+                                        let already_stored =
+                                            existing_time.map_or(false, |existing_time| {
+                                                if &modified_time != existing_time
+                                                    && existing_time.elapsed().unwrap().as_secs()
+                                                        > REINDEXING_DELAY
+                                                {
+                                                    false
+                                                } else {
+                                                    true
+                                                }
+                                            });
+
+                                        if !already_stored {
+                                            log::info!("Need to reindex: {:?}", &change_path);
+                                            parsing_files_tx
+                                                .try_send((
+                                                    worktree_db_id,
+                                                    change_path.to_path_buf(),
+                                                    language,
+                                                    modified_time,
+                                                ))
+                                                .unwrap();
+                                        }
+                                    }
+                                }
+                            })
+                        }
                     }
                 });
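
The re-enabled handler decides per changed path whether to reindex: a path with no stored mtime is always indexed, while a previously indexed path is reindexed only when its on-disk mtime differs from the stored one and the stored entry is older than REINDEXING_DELAY seconds, which debounces rapid saves. A condensed sketch of that decision; the REINDEXING_DELAY value shown is an assumption:

use std::time::SystemTime;

const REINDEXING_DELAY: u64 = 30; // seconds; illustrative value

// Returns true when a changed path should be sent back through the parser.
fn needs_reindex(modified_time: SystemTime, stored_mtime: Option<SystemTime>) -> bool {
    match stored_mtime {
        // Never indexed before: always index.
        None => true,
        // Indexed before: only reindex if the file actually changed and the
        // stored entry is old enough that rapid saves are debounced.
        Some(existing) => {
            let stale = existing
                .elapsed()
                .map(|age| age.as_secs() > REINDEXING_DELAY)
                .unwrap_or(false);
            modified_time != existing && stale
        }
    }
}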