@@ -2,15 +2,20 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::serde_json;
+use isahc::http::StatusCode;
use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
+use std::env;
use std::sync::Arc;
-use std::{env, time::Instant};
+use std::time::Duration;
+use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+    // Shared tokenizer, built on first use instead of once per truncation call.
+    // Failing to build it is treated as a programming error, hence the panic.
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE =
+        cl100k_base().expect("failed to construct the cl100k_base tokenizer");
}
#[derive(Clone)]
@@ -60,69 +65,100 @@ impl EmbeddingProvider for DummyEmbeddings {
}
}
-// impl OpenAIEmbeddings {
-// async fn truncate(span: &str) -> String {
-// let bpe = cl100k_base().unwrap();
-// let mut tokens = bpe.encode_with_special_tokens(span);
-// if tokens.len() > 8192 {
-// tokens.truncate(8192);
-// let result = bpe.decode(tokens);
-// if result.is_ok() {
-// return result.unwrap();
-// }
-// }
-
-// return span.to_string();
-// }
-// }
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
- async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
- // Truncate spans to 8192 if needed
- // let t0 = Instant::now();
- // let mut truncated_spans = vec![];
- // for span in spans {
- // truncated_spans.push(Self::truncate(span));
- // }
- // let spans = futures::future::join_all(truncated_spans).await;
- // log::info!("Truncated Spans in {:?}", t0.elapsed().as_secs());
+impl OpenAIEmbeddings {
+    /// Shortens `span` so that it encodes to at most 8190 tokens.
+    ///
+    /// Spans already within the limit are returned untouched. Longer spans are
+    /// cut at the token level and decoded back to text; if decoding the
+    /// shortened token sequence fails, the original text is returned unchanged.
+    async fn truncate(span: String) -> String {
+        // NOTE(review): presumably chosen to stay under the embedding model's
+        // input limit - confirm against the OpenAI documentation.
+        const MAX_TOKENS: usize = 8190;
+
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(&span);
+        if tokens.len() <= MAX_TOKENS {
+            return span;
+        }
+
+        tokens.truncate(MAX_TOKENS);
+        // `tokens` is not needed afterwards, so it is moved into `decode`
+        // rather than cloned.
+        OPENAI_BPE_TOKENIZER.decode(tokens).unwrap_or(span)
+    }
- let api_key = OPENAI_API_KEY
- .as_ref()
- .ok_or_else(|| anyhow!("no api key"))?;
+
+    /// Posts `spans` to the OpenAI embeddings endpoint and returns the raw response.
+    ///
+    /// The HTTP status is deliberately not inspected here: the caller decides
+    /// whether to retry, truncate, or fail. Only request-construction and
+    /// transport errors are returned as `Err`.
+    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
let request = Request::post("https://api.openai.com/v1/embeddings")
.redirect_policy(isahc::config::RedirectPolicy::Follow)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key))
.body(
serde_json::to_string(&OpenAIEmbeddingRequest {
input: spans,
model: "text-embedding-ada-002",
})
.unwrap()
.into(),
)?;
- let mut response = self.client.send(request).await?;
- if !response.status().is_success() {
- return Err(anyhow!("openai embedding failed {}", response.status()));
- }
+        Ok(self.client.send(request).await?)
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddings {
+    /// Requests one embedding per span from OpenAI, retrying on recoverable failures.
+    ///
+    /// Up to `MAX_RETRIES` requests are sent in total:
+    /// - `200 OK`: the response body is parsed and the embeddings are returned.
+    /// - `429 Too Many Requests`: waits for the matching back-off, then retries.
+    /// - `400 Bad Request`: truncates every span and retries immediately.
+    /// - any other status: fails right away.
+    ///
+    /// # Errors
+    /// Returns an error when no API key is configured, when sending or parsing
+    /// fails, on a non-retryable status, or once the final attempt has failed.
+    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+        // Seconds to wait after the first, second, ... rate-limited attempt.
+        const BACKOFF_SECONDS: [u64; 3] = [65, 180, 360];
+        // Total number of requests sent before giving up.
+        const MAX_RETRIES: usize = 3;
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+        let api_key = OPENAI_API_KEY
+            .as_ref()
+            .ok_or_else(|| anyhow!("no api key"))?;
- log::info!(
- "openai embedding completed. tokens: {:?}",
- response.usage.total_tokens
- );
+
+        // Owned copies, so spans can be truncated in place between attempts.
+        let mut spans: Vec<String> = spans.iter().map(|span| span.to_string()).collect();
+        for attempt in 0..MAX_RETRIES {
+            let mut response = self
+                .send_request(api_key, spans.iter().map(String::as_str).collect())
+                .await?;
+            let status = response.status();
+
+            if status == StatusCode::OK {
+                let mut body = String::new();
+                response.body_mut().read_to_string(&mut body).await?;
+                let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                log::info!(
+                    "openai embedding completed. tokens: {:?}",
+                    response.usage.total_tokens
+                );
+                return Ok(response
+                    .data
+                    .into_iter()
+                    .map(|embedding| embedding.embedding)
+                    .collect());
+            }
+
+            // The final attempt failed: report its status instead of waiting again.
+            if attempt + 1 == MAX_RETRIES {
+                return Err(anyhow!("openai max retries, error: {:?}", status));
+            }
+
+            match status {
+                StatusCode::TOO_MANY_REQUESTS => {
+                    // Clamp the index so raising MAX_RETRIES cannot index out of bounds.
+                    let seconds = BACKOFF_SECONDS[attempt.min(BACKOFF_SECONDS.len() - 1)];
+                    // Async timer: a blocking sleep would stall the executor thread
+                    // for the whole back-off period.
+                    smol::Timer::after(Duration::from_secs(seconds)).await;
+                }
+                StatusCode::BAD_REQUEST => {
+                    log::info!("BAD REQUEST: {:?}", status);
+                    // Don't worry about delaying bad request, as we can assume
+                    // we haven't been rate limited yet.
+                    for span in spans.iter_mut() {
+                        // Move the text out instead of cloning it for truncation.
+                        *span = Self::truncate(std::mem::take(span)).await;
+                    }
+                }
+                _ => return Err(anyhow!("openai embedding failed {}", status)),
+            }
+        }
- Ok(response
- .data
- .into_iter()
- .map(|embedding| embedding.embedding)
- .collect())
+
+        // Only reachable if MAX_RETRIES is zero.
+        Err(anyhow!("openai embedding failed"))
}
}
@@ -74,7 +74,6 @@ pub fn init(
cx.subscribe_global::<WorkspaceCreated, _>({
let vector_store = vector_store.clone();
move |event, cx| {
- let t0 = Instant::now();
let workspace = &event.0;
if let Some(workspace) = workspace.upgrade(cx) {
let project = workspace.read(cx).project().clone();
@@ -126,9 +125,7 @@ pub struct VectorStore {
language_registry: Arc<LanguageRegistry>,
db_update_tx: channel::Sender<DbWrite>,
// embed_batch_tx: channel::Sender<Vec<(i64, IndexedFile, Vec<String>)>>,
- batch_files_tx: channel::Sender<(i64, IndexedFile, Vec<String>)>,
parsing_files_tx: channel::Sender<(i64, PathBuf, Arc<Language>, SystemTime)>,
- parsing_files_rx: channel::Receiver<(i64, PathBuf, Arc<Language>, SystemTime)>,
_db_update_task: Task<()>,
_embed_batch_task: Vec<Task<()>>,
_batch_files_task: Task<()>,
@@ -220,14 +217,13 @@ impl VectorStore {
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, IndexedFile, Vec<String>)>>();
let mut _embed_batch_task = Vec::new();
- for _ in 0..cx.background().num_cpus() {
+ for _ in 0..1 {
+ //cx.background().num_cpus() {
let db_update_tx = db_update_tx.clone();
let embed_batch_rx = embed_batch_rx.clone();
let embedding_provider = embedding_provider.clone();
_embed_batch_task.push(cx.background().spawn(async move {
while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
- log::info!("Embedding Batch! ");
-
// Construct Batch
let mut embeddings_queue = embeddings_queue.clone();
let mut document_spans = vec![];
@@ -235,20 +231,20 @@ impl VectorStore {
document_spans.extend(document_span);
}
- if let Some(mut embeddings) = embedding_provider
+ if let Ok(embeddings) = embedding_provider
.embed_batch(document_spans.iter().map(|x| &**x).collect())
.await
- .log_err()
{
let mut i = 0;
let mut j = 0;
- while let Some(embedding) = embeddings.pop() {
+
+ for embedding in embeddings.iter() {
while embeddings_queue[i].1.documents.len() == j {
i += 1;
j = 0;
}
- embeddings_queue[i].1.documents[j].embedding = embedding;
+ embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
j += 1;
}
@@ -283,7 +279,6 @@ impl VectorStore {
while let Ok((worktree_id, indexed_file, document_spans)) =
batch_files_rx.recv().await
{
- log::info!("Batching File: {:?}", &indexed_file.path);
queue_len += &document_spans.len();
embeddings_queue.push((worktree_id, indexed_file, document_spans));
if queue_len >= EMBEDDINGS_BATCH_SIZE {
@@ -338,10 +333,7 @@ impl VectorStore {
embedding_provider,
language_registry,
db_update_tx,
- // embed_batch_tx,
- batch_files_tx,
parsing_files_tx,
- parsing_files_rx,
_db_update_task,
_embed_batch_task,
_batch_files_task,
@@ -449,8 +441,6 @@ impl VectorStore {
let database_url = self.database_url.clone();
let db_update_tx = self.db_update_tx.clone();
let parsing_files_tx = self.parsing_files_tx.clone();
- let parsing_files_rx = self.parsing_files_rx.clone();
- let batch_files_tx = self.batch_files_tx.clone();
cx.spawn(|this, mut cx| async move {
let t0 = Instant::now();
@@ -553,37 +543,6 @@ impl VectorStore {
})
.detach();
- // cx.background()
- // .scoped(|scope| {
- // for _ in 0..cx.background().num_cpus() {
- // scope.spawn(async {
- // let mut parser = Parser::new();
- // let mut cursor = QueryCursor::new();
- // while let Ok((worktree_id, file_path, language, mtime)) =
- // parsing_files_rx.recv().await
- // {
- // log::info!("Parsing File: {:?}", &file_path);
- // if let Some((indexed_file, document_spans)) = Self::index_file(
- // &mut cursor,
- // &mut parser,
- // &fs,
- // language,
- // file_path.clone(),
- // mtime,
- // )
- // .await
- // .log_err()
- // {
- // batch_files_tx
- // .try_send((worktree_id, indexed_file, document_spans))
- // .unwrap();
- // }
- // }
- // });
- // }
- // })
- // .await;
-
this.update(&mut cx, |this, cx| {
// The below is managing for updated on save
// Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
@@ -592,90 +551,90 @@ impl VectorStore {
if let Some(project_state) = this.projects.get(&project.downgrade()) {
let worktree_db_ids = project_state.worktree_db_ids.clone();
- // if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
- // {
- // // Iterate through changes
- // let language_registry = this.language_registry.clone();
-
- // let db =
- // VectorDatabase::new(this.database_url.to_string_lossy().into());
- // if db.is_err() {
- // return;
- // }
- // let db = db.unwrap();
-
- // let worktree_db_id: Option<i64> = {
- // let mut found_db_id = None;
- // for (w_id, db_id) in worktree_db_ids.into_iter() {
- // if &w_id == worktree_id {
- // found_db_id = Some(db_id);
- // }
- // }
-
- // found_db_id
- // };
-
- // if worktree_db_id.is_none() {
- // return;
- // }
- // let worktree_db_id = worktree_db_id.unwrap();
-
- // let file_mtimes = db.get_file_mtimes(worktree_db_id);
- // if file_mtimes.is_err() {
- // return;
- // }
-
- // let file_mtimes = file_mtimes.unwrap();
- // let paths_tx = this.paths_tx.clone();
-
- // smol::block_on(async move {
- // for change in changes.into_iter() {
- // let change_path = change.0.clone();
- // log::info!("Change: {:?}", &change_path);
- // if let Ok(language) = language_registry
- // .language_for_file(&change_path.to_path_buf(), None)
- // .await
- // {
- // if language
- // .grammar()
- // .and_then(|grammar| grammar.embedding_config.as_ref())
- // .is_none()
- // {
- // continue;
- // }
-
- // // TODO: Make this a bit more defensive
- // let modified_time =
- // change_path.metadata().unwrap().modified().unwrap();
- // let existing_time =
- // file_mtimes.get(&change_path.to_path_buf());
- // let already_stored =
- // existing_time.map_or(false, |existing_time| {
- // if &modified_time != existing_time
- // && existing_time.elapsed().unwrap().as_secs()
- // > REINDEXING_DELAY
- // {
- // false
- // } else {
- // true
- // }
- // });
-
- // if !already_stored {
- // log::info!("Need to reindex: {:?}", &change_path);
- // paths_tx
- // .try_send((
- // worktree_db_id,
- // change_path.to_path_buf(),
- // language,
- // modified_time,
- // ))
- // .unwrap();
- // }
- // }
- // }
- // })
- // }
+ if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event
+ {
+ // Iterate through changes
+ let language_registry = this.language_registry.clone();
+
+ let db =
+ VectorDatabase::new(this.database_url.to_string_lossy().into());
+ if db.is_err() {
+ return;
+ }
+ let db = db.unwrap();
+
+ let worktree_db_id: Option<i64> = {
+ let mut found_db_id = None;
+ for (w_id, db_id) in worktree_db_ids.into_iter() {
+ if &w_id == worktree_id {
+ found_db_id = Some(db_id);
+ }
+ }
+
+ found_db_id
+ };
+
+ if worktree_db_id.is_none() {
+ return;
+ }
+ let worktree_db_id = worktree_db_id.unwrap();
+
+ let file_mtimes = db.get_file_mtimes(worktree_db_id);
+ if file_mtimes.is_err() {
+ return;
+ }
+
+ let file_mtimes = file_mtimes.unwrap();
+ let parsing_files_tx = this.parsing_files_tx.clone();
+
+ smol::block_on(async move {
+ for change in changes.into_iter() {
+ let change_path = change.0.clone();
+ log::info!("Change: {:?}", &change_path);
+ if let Ok(language) = language_registry
+ .language_for_file(&change_path.to_path_buf(), None)
+ .await
+ {
+ if language
+ .grammar()
+ .and_then(|grammar| grammar.embedding_config.as_ref())
+ .is_none()
+ {
+ continue;
+ }
+
+ // TODO: Make this a bit more defensive
+ let modified_time =
+ change_path.metadata().unwrap().modified().unwrap();
+ let existing_time =
+ file_mtimes.get(&change_path.to_path_buf());
+ let already_stored =
+ existing_time.map_or(false, |existing_time| {
+ if &modified_time != existing_time
+ && existing_time.elapsed().unwrap().as_secs()
+ > REINDEXING_DELAY
+ {
+ false
+ } else {
+ true
+ }
+ });
+
+ if !already_stored {
+ log::info!("Need to reindex: {:?}", &change_path);
+ parsing_files_tx
+ .try_send((
+ worktree_db_id,
+ change_path.to_path_buf(),
+ language,
+ modified_time,
+ ))
+ .unwrap();
+ }
+ }
+ }
+ })
+ }
}
});