@@ -2,7 +2,7 @@ use crate::{
SearchOption, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleRegex,
ToggleWholeWord,
};
-use anyhow::Result;
+use anyhow::{Context, Result};
use collections::HashMap;
use editor::{
items::active_match_index, scroll::autoscroll::Autoscroll, Anchor, Editor, MultiBuffer,
@@ -18,7 +18,9 @@ use gpui::{
Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
};
use menu::Confirm;
+use postage::stream::Stream;
use project::{search::SearchQuery, Project};
+use semantic_index::SemanticIndex;
use smallvec::SmallVec;
use std::{
any::{Any, TypeId},
@@ -36,7 +38,10 @@ use workspace::{
ItemNavHistory, Pane, ToolbarItemLocation, ToolbarItemView, Workspace, WorkspaceId,
};
-actions!(project_search, [SearchInNew, ToggleFocus, NextField]);
+actions!(
+ project_search,
+ [SearchInNew, ToggleFocus, NextField, ToggleSemanticSearch]
+);
#[derive(Default)]
struct ActiveSearches(HashMap<WeakModelHandle<Project>, WeakViewHandle<ProjectSearchView>>);
@@ -92,6 +97,7 @@ pub struct ProjectSearchView {
case_sensitive: bool,
whole_word: bool,
regex: bool,
+ semantic: Option<SemanticSearchState>,
panels_with_errors: HashSet<InputPanel>,
active_match_index: Option<usize>,
search_id: usize,
@@ -100,6 +106,13 @@ pub struct ProjectSearchView {
excluded_files_editor: ViewHandle<Editor>,
}
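+/// State for semantic search in a project search view: how many files are being
+/// indexed, how many are still outstanding, and the in-flight search task, if any.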
+struct SemanticSearchState {
+ file_count: usize,
+ outstanding_file_count: usize,
+ _progress_task: Task<()>,
+ search_task: Option<Task<Result<()>>>,
+}
+
pub struct ProjectSearchBar {
active_project_search: Option<ViewHandle<ProjectSearchView>>,
subscription: Option<Subscription>,
@@ -198,12 +211,25 @@ impl View for ProjectSearchView {
let theme = theme::current(cx).clone();
let text = if self.query_editor.read(cx).text(cx).is_empty() {
- ""
+ Cow::Borrowed("")
+ } else if let Some(semantic) = &self.semantic {
+ if semantic.search_task.is_some() {
+ Cow::Borrowed("Searching...")
+ } else if semantic.outstanding_file_count > 0 {
+ Cow::Owned(format!(
+                    "Indexing: {} of {}...",
+ semantic.file_count - semantic.outstanding_file_count,
+ semantic.file_count
+ ))
+ } else {
+ Cow::Borrowed("Indexing complete")
+ }
} else if model.pending_search.is_some() {
- "Searching..."
+ Cow::Borrowed("Searching...")
} else {
- "No results"
+ Cow::Borrowed("No results")
};
+
MouseEventHandler::<Status, _>::new(0, cx, |_, _| {
Label::new(text, theme.search.results_status.clone())
.aligned()
@@ -499,6 +525,7 @@ impl ProjectSearchView {
case_sensitive,
whole_word,
regex,
+ semantic: None,
panels_with_errors: HashSet::new(),
active_match_index: None,
query_editor_was_focused: false,
@@ -563,6 +590,35 @@ impl ProjectSearchView {
}
fn search(&mut self, cx: &mut ViewContext<Self>) {
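+        // When semantic search is enabled, route the query through the semantic index
+        // instead of running a text search.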
+ if let Some(semantic) = &mut self.semantic {
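+            // Indexing is still in progress; wait for it to finish before searching.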
+ if semantic.outstanding_file_count > 0 {
+ return;
+ }
+
+ let search_phrase = self.query_editor.read(cx).text(cx);
+ let project = self.model.read(cx).project.clone();
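+            // Ask the global semantic index for the top 10 matches in this project.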
+ if let Some(semantic_index) = SemanticIndex::global(cx) {
+ let search_task = semantic_index.update(cx, |semantic_index, cx| {
+ semantic_index.search_project(project, search_phrase, 10, cx)
+ });
+ semantic.search_task = Some(cx.spawn(|this, mut cx| async move {
+ let results = search_task.await.context("search task")?;
+
+ this.update(&mut cx, |this, cx| {
+ dbg!(&results);
+ // TODO: Update results
+
+ if let Some(semantic) = &mut this.semantic {
+ semantic.search_task = None;
+ }
+ })?;
+
+ anyhow::Ok(())
+ }));
+ }
+ return;
+ }
+
if let Some(query) = self.build_search_query(cx) {
self.model.update(cx, |model, cx| model.search(query, cx));
}
@@ -876,6 +932,59 @@ impl ProjectSearchBar {
}
}
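+    // Toggles semantic mode on the active project search; enabling it kicks off
+    // indexing of the project and wires up progress reporting.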
+ fn toggle_semantic_search(&mut self, cx: &mut ViewContext<Self>) -> bool {
+ if let Some(search_view) = self.active_project_search.as_ref() {
+ search_view.update(cx, |search_view, cx| {
+ if search_view.semantic.is_some() {
+ search_view.semantic = None;
+ } else if let Some(semantic_index) = SemanticIndex::global(cx) {
+ // TODO: confirm that it's ok to send this project
+
+ let project = search_view.model.read(cx).project.clone();
+ let index_task = semantic_index.update(cx, |semantic_index, cx| {
+ semantic_index.index_project(project, cx)
+ });
+
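+                    // Once indexing starts, record the total file count and watch the
+                    // remaining-file channel to drive the progress shown in the status text.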
+ cx.spawn(|search_view, mut cx| async move {
+ let (files_to_index, mut files_remaining_rx) = index_task.await?;
+
+ search_view.update(&mut cx, |search_view, cx| {
+ search_view.semantic = Some(SemanticSearchState {
+ file_count: files_to_index,
+ outstanding_file_count: files_to_index,
+ search_task: None,
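+                                // Keep the outstanding file count up to date and notify the
+                                // view so the status text re-renders as indexing progresses.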
+ _progress_task: cx.spawn(|search_view, mut cx| async move {
+ while let Some(count) = files_remaining_rx.recv().await {
+ search_view
+ .update(&mut cx, |search_view, cx| {
+ if let Some(semantic_search_state) =
+ &mut search_view.semantic
+ {
+ semantic_search_state.outstanding_file_count =
+ count;
+ cx.notify();
+ if count == 0 {
+ return;
+ }
+ }
+ })
+ .ok();
+ }
+ }),
+ });
+ })?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+ });
+ cx.notify();
+ true
+ } else {
+ false
+ }
+ }
+
fn render_nav_button(
&self,
icon: &'static str,
@@ -953,6 +1062,42 @@ impl ProjectSearchBar {
.into_any()
}
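+    // Renders the "Semantic" toggle button shown alongside the other search option buttons.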
+ fn render_semantic_search_button(&self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+ let tooltip_style = theme::current(cx).tooltip.clone();
+ let is_active = if let Some(search) = self.active_project_search.as_ref() {
+ let search = search.read(cx);
+ search.semantic.is_some()
+ } else {
+ false
+ };
+
+ let region_id = 3;
+
+ MouseEventHandler::<Self, _>::new(region_id, cx, |state, cx| {
+ let theme = theme::current(cx);
+ let style = theme
+ .search
+ .option_button
+ .in_state(is_active)
+ .style_for(state);
+ Label::new("Semantic", style.text.clone())
+ .contained()
+ .with_style(style.container)
+ })
+ .on_click(MouseButton::Left, move |_, this, cx| {
+ this.toggle_semantic_search(cx);
+ })
+ .with_cursor_style(CursorStyle::PointingHand)
+ .with_tooltip::<Self>(
+ region_id,
+            "Toggle Semantic Search".to_string(),
+ Some(Box::new(ToggleSemanticSearch)),
+ tooltip_style,
+ cx,
+ )
+ .into_any()
+ }
+
fn is_option_enabled(&self, option: SearchOption, cx: &AppContext) -> bool {
if let Some(search) = self.active_project_search.as_ref() {
let search = search.read(cx);
@@ -1049,6 +1194,7 @@ impl View for ProjectSearchBar {
)
.with_child(
Flex::row()
+ .with_child(self.render_semantic_search_button(cx))
.with_child(self.render_option_button(
"Case",
SearchOption::CaseSensitive,
@@ -1,172 +0,0 @@
-use crate::{SearchResult, SemanticIndex};
-use editor::{scroll::autoscroll::Autoscroll, Editor};
-use gpui::{
- actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext,
- WeakViewHandle,
-};
-use picker::{Picker, PickerDelegate, PickerEvent};
-use project::{Project, ProjectPath};
-use std::{collections::HashMap, sync::Arc, time::Duration};
-use util::ResultExt;
-use workspace::Workspace;
-
-const MIN_QUERY_LEN: usize = 5;
-const EMBEDDING_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(500);
-
-actions!(semantic_search, [Toggle]);
-
-pub type SemanticSearch = Picker<SemanticSearchDelegate>;
-
-pub struct SemanticSearchDelegate {
- workspace: WeakViewHandle<Workspace>,
- project: ModelHandle<Project>,
- semantic_index: ModelHandle<SemanticIndex>,
- selected_match_index: usize,
- matches: Vec<SearchResult>,
- history: HashMap<String, Vec<SearchResult>>,
-}
-
-impl SemanticSearchDelegate {
- // This is currently searching on every keystroke,
- // This is wildly overkill, and has the potential to get expensive
- // We will need to update this to throttle searching
- pub fn new(
- workspace: WeakViewHandle<Workspace>,
- project: ModelHandle<Project>,
- semantic_index: ModelHandle<SemanticIndex>,
- ) -> Self {
- Self {
- workspace,
- project,
- semantic_index,
- selected_match_index: 0,
- matches: vec![],
- history: HashMap::new(),
- }
- }
-}
-
-impl PickerDelegate for SemanticSearchDelegate {
- fn placeholder_text(&self) -> Arc<str> {
- "Search repository in natural language...".into()
- }
-
- fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
- if let Some(search_result) = self.matches.get(self.selected_match_index) {
- // Open Buffer
- let search_result = search_result.clone();
- let buffer = self.project.update(cx, |project, cx| {
- project.open_buffer(
- ProjectPath {
- worktree_id: search_result.worktree_id,
- path: search_result.file_path.clone().into(),
- },
- cx,
- )
- });
-
- let workspace = self.workspace.clone();
- let position = search_result.clone().byte_range.start;
- cx.spawn(|_, mut cx| async move {
- let buffer = buffer.await?;
- workspace.update(&mut cx, |workspace, cx| {
- let editor = workspace.open_project_item::<Editor>(buffer, cx);
- editor.update(cx, |editor, cx| {
- editor.change_selections(Some(Autoscroll::center()), cx, |s| {
- s.select_ranges([position..position])
- });
- });
- })?;
- Ok::<_, anyhow::Error>(())
- })
- .detach_and_log_err(cx);
- cx.emit(PickerEvent::Dismiss);
- }
- }
-
- fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
-
- fn match_count(&self) -> usize {
- self.matches.len()
- }
-
- fn selected_index(&self) -> usize {
- self.selected_match_index
- }
-
- fn set_selected_index(&mut self, ix: usize, _cx: &mut ViewContext<SemanticSearch>) {
- self.selected_match_index = ix;
- }
-
- fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
- log::info!("Searching for {:?}...", query);
- if query.len() < MIN_QUERY_LEN {
- log::info!("Query below minimum length");
- return Task::ready(());
- }
-
- let semantic_index = self.semantic_index.clone();
- let project = self.project.clone();
- cx.spawn(|this, mut cx| async move {
- cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await;
-
- let retrieved_cached = this.update(&mut cx, |this, _| {
- let delegate = this.delegate_mut();
- if delegate.history.contains_key(&query) {
- let historic_results = delegate.history.get(&query).unwrap().to_owned();
- delegate.matches = historic_results.clone();
- true
- } else {
- false
- }
- });
-
- if let Some(retrieved) = retrieved_cached.log_err() {
- if !retrieved {
- let task = semantic_index.update(&mut cx, |store, cx| {
- store.search_project(project.clone(), query.to_string(), 10, cx)
- });
-
- if let Some(results) = task.await.log_err() {
- log::info!("Not queried previously, searching...");
- this.update(&mut cx, |this, _| {
- let delegate = this.delegate_mut();
- delegate.matches = results.clone();
- delegate.history.insert(query, results);
- })
- .ok();
- }
- } else {
- log::info!("Already queried, retrieved directly from cached history");
- }
- }
- })
- }
-
- fn render_match(
- &self,
- ix: usize,
- mouse_state: &mut MouseState,
- selected: bool,
- cx: &AppContext,
- ) -> AnyElement<Picker<Self>> {
- let theme = theme::current(cx);
- let style = &theme.picker.item;
- let current_style = style.in_state(selected).style_for(mouse_state);
-
- let search_result = &self.matches[ix];
-
- let path = search_result.file_path.to_string_lossy();
- let name = search_result.name.clone();
-
- Flex::column()
- .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
- .with_child(Label::new(
- path.to_string(),
- style.inactive_state().default.label.clone(),
- ))
- .contained()
- .with_style(current_style.container)
- .into_any()
- }
-}
@@ -1,6 +1,5 @@
mod db;
mod embedding;
-mod modal;
mod parsing;
mod semantic_index_settings;
@@ -12,25 +11,20 @@ use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
-use gpui::{
- AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
- WeakModelHandle,
-};
+use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Language, LanguageRegistry};
-use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use postage::watch;
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
- collections::{HashMap, HashSet},
+ collections::HashMap,
+ mem,
ops::Range,
path::{Path, PathBuf},
- sync::{
- atomic::{self, AtomicUsize},
- Arc, Weak,
- },
- time::{Instant, SystemTime},
+ sync::{Arc, Weak},
+ time::SystemTime,
};
use util::{
channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -38,9 +32,8 @@ use util::{
paths::EMBEDDINGS_DIR,
ResultExt,
};
-use workspace::{Workspace, WorkspaceCreated};
-const SEMANTIC_INDEX_VERSION: usize = 1;
+const SEMANTIC_INDEX_VERSION: usize = 3;
const EMBEDDINGS_BATCH_SIZE: usize = 150;
pub fn init(
@@ -55,25 +48,6 @@ pub fn init(
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
- SemanticSearch::init(cx);
- cx.add_action(
- |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
- if cx.has_global::<ModelHandle<SemanticIndex>>() {
- let semantic_index = cx.global::<ModelHandle<SemanticIndex>>().clone();
- workspace.toggle_modal(cx, |workspace, cx| {
- let project = workspace.project().clone();
- let workspace = cx.weak_handle();
- cx.add_view(|cx| {
- SemanticSearch::new(
- SemanticSearchDelegate::new(workspace, project, semantic_index),
- cx,
- )
- })
- });
- }
- },
- );
-
if *RELEASE_CHANNEL == ReleaseChannel::Stable
|| !settings::get::<SemanticIndexSettings>(cx).enabled
{
@@ -95,21 +69,6 @@ pub fn init(
cx.update(|cx| {
cx.set_global(semantic_index.clone());
- cx.subscribe_global::<WorkspaceCreated, _>({
- let semantic_index = semantic_index.clone();
- move |event, cx| {
- let workspace = &event.0;
- if let Some(workspace) = workspace.upgrade(cx) {
- let project = workspace.read(cx).project().clone();
- if project.read(cx).is_local() {
- semantic_index.update(cx, |store, cx| {
- store.index_project(project, cx).detach();
- });
- }
- }
- }
- })
- .detach();
});
anyhow::Ok(())
@@ -128,20 +87,17 @@ pub struct SemanticIndex {
_embed_batch_task: Task<()>,
_batch_files_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
- next_job_id: Arc<AtomicUsize>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
struct ProjectState {
worktree_db_ids: Vec<(WorktreeId, i64)>,
- outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
+ outstanding_job_count_rx: watch::Receiver<usize>,
+ outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
}
-type JobId = usize;
-
struct JobHandle {
- id: JobId,
- set: Weak<Mutex<HashSet<JobId>>>,
+ tx: Weak<Mutex<watch::Sender<usize>>>,
}
impl ProjectState {
@@ -221,6 +177,14 @@ enum EmbeddingJob {
}
impl SemanticIndex {
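+    /// Returns the global SemanticIndex handle, if the index has been initialized.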
+ pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+ if cx.has_global::<ModelHandle<Self>>() {
+ Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
+ } else {
+ None
+ }
+ }
+
async fn new(
fs: Arc<dyn Fs>,
database_url: PathBuf,
@@ -236,184 +200,69 @@ impl SemanticIndex {
.await?;
Ok(cx.add_model(|cx| {
- // paths_tx -> embeddings_tx -> db_update_tx
-
- //db_update_tx/rx: Updating Database
+ // Perform database operations
let (db_update_tx, db_update_rx) = channel::unbounded();
- let _db_update_task = cx.background().spawn(async move {
- while let Ok(job) = db_update_rx.recv().await {
- match job {
- DbOperation::InsertFile {
- worktree_id,
- documents,
- path,
- mtime,
- job_handle,
- } => {
- db.insert_file(worktree_id, path, mtime, documents)
- .log_err();
- drop(job_handle)
- }
- DbOperation::Delete { worktree_id, path } => {
- db.delete_file(worktree_id, path).log_err();
- }
- DbOperation::FindOrCreateWorktree { path, sender } => {
- let id = db.find_or_create_worktree(&path);
- sender.send(id).ok();
- }
- DbOperation::FileMTimes {
- worktree_id: worktree_db_id,
- sender,
- } => {
- let file_mtimes = db.get_file_mtimes(worktree_db_id);
- sender.send(file_mtimes).ok();
- }
+ let _db_update_task = cx.background().spawn({
+ async move {
+ while let Ok(job) = db_update_rx.recv().await {
+ Self::run_db_operation(&db, job)
}
}
});
- // embed_tx/rx: Embed Batch and Send to Database
+        // Compute embeddings for each batch of parsed documents and hand the results to the database writer.
let (embed_batch_tx, embed_batch_rx) =
channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
let _embed_batch_task = cx.background().spawn({
let db_update_tx = db_update_tx.clone();
let embedding_provider = embedding_provider.clone();
async move {
- while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
- // Construct Batch
- let mut batch_documents = vec![];
- for (_, documents, _, _, _) in embeddings_queue.iter() {
- batch_documents
- .extend(documents.iter().map(|document| document.content.as_str()));
- }
-
- if let Ok(embeddings) =
- embedding_provider.embed_batch(batch_documents).await
- {
- log::trace!(
- "created {} embeddings for {} files",
- embeddings.len(),
- embeddings_queue.len(),
- );
-
- let mut i = 0;
- let mut j = 0;
-
- for embedding in embeddings.iter() {
- while embeddings_queue[i].1.len() == j {
- i += 1;
- j = 0;
- }
-
- embeddings_queue[i].1[j].embedding = embedding.to_owned();
- j += 1;
- }
-
- for (worktree_id, documents, path, mtime, job_handle) in
- embeddings_queue.into_iter()
- {
- for document in documents.iter() {
- // TODO: Update this so it doesn't panic
- assert!(
- document.embedding.len() > 0,
- "Document Embedding Not Complete"
- );
- }
-
- db_update_tx
- .send(DbOperation::InsertFile {
- worktree_id,
- documents,
- path,
- mtime,
- job_handle,
- })
- .await
- .unwrap();
- }
- }
+ while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+ Self::compute_embeddings_for_batch(
+ embeddings_queue,
+ &embedding_provider,
+ &db_update_tx,
+ )
+ .await;
}
}
});
- // batch_tx/rx: Batch Files to Send for Embeddings
+ // Group documents into batches and send them to the embedding provider.
let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
let _batch_files_task = cx.background().spawn(async move {
let mut queue_len = 0;
let mut embeddings_queue = vec![];
-
while let Ok(job) = batch_files_rx.recv().await {
- let should_flush = match job {
- EmbeddingJob::Enqueue {
- documents,
- worktree_id,
- path,
- mtime,
- job_handle,
- } => {
- queue_len += &documents.len();
- embeddings_queue.push((
- worktree_id,
- documents,
- path,
- mtime,
- job_handle,
- ));
- queue_len >= EMBEDDINGS_BATCH_SIZE
- }
- EmbeddingJob::Flush => true,
- };
-
- if should_flush {
- embed_batch_tx.try_send(embeddings_queue).unwrap();
- embeddings_queue = vec![];
- queue_len = 0;
- }
+ Self::enqueue_documents_to_embed(
+ job,
+ &mut queue_len,
+ &mut embeddings_queue,
+ &embed_batch_tx,
+ );
}
});
- // parsing_files_tx/rx: Parsing Files to Embeddable Documents
+ // Parse files into embeddable documents.
let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
-
let mut _parsing_files_tasks = Vec::new();
for _ in 0..cx.background().num_cpus() {
let fs = fs.clone();
let parsing_files_rx = parsing_files_rx.clone();
let batch_files_tx = batch_files_tx.clone();
+ let db_update_tx = db_update_tx.clone();
_parsing_files_tasks.push(cx.background().spawn(async move {
let mut retriever = CodeContextRetriever::new();
while let Ok(pending_file) = parsing_files_rx.recv().await {
- if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
- {
- if let Some(documents) = retriever
- .parse_file(
- &pending_file.relative_path,
- &content,
- pending_file.language,
- )
- .log_err()
- {
- log::trace!(
- "parsed path {:?}: {} documents",
- pending_file.relative_path,
- documents.len()
- );
-
- batch_files_tx
- .try_send(EmbeddingJob::Enqueue {
- worktree_id: pending_file.worktree_db_id,
- path: pending_file.relative_path,
- mtime: pending_file.modified_time,
- job_handle: pending_file.job_handle,
- documents,
- })
- .unwrap();
- }
- }
-
- if parsing_files_rx.len() == 0 {
- batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
- }
+ Self::parse_file(
+ &fs,
+ pending_file,
+ &mut retriever,
+ &batch_files_tx,
+ &parsing_files_rx,
+ &db_update_tx,
+ )
+ .await;
}
}));
}
@@ -424,7 +273,6 @@ impl SemanticIndex {
embedding_provider,
language_registry,
db_update_tx,
- next_job_id: Default::default(),
parsing_files_tx,
_db_update_task,
_embed_batch_task,
@@ -435,6 +283,167 @@ impl SemanticIndex {
}))
}
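+    /// Applies a single queued operation to the vector database.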
+ fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
+ match job {
+ DbOperation::InsertFile {
+ worktree_id,
+ documents,
+ path,
+ mtime,
+ job_handle,
+ } => {
+ db.insert_file(worktree_id, path, mtime, documents)
+ .log_err();
+ drop(job_handle)
+ }
+ DbOperation::Delete { worktree_id, path } => {
+ db.delete_file(worktree_id, path).log_err();
+ }
+ DbOperation::FindOrCreateWorktree { path, sender } => {
+ let id = db.find_or_create_worktree(&path);
+ sender.send(id).ok();
+ }
+ DbOperation::FileMTimes {
+ worktree_id: worktree_db_id,
+ sender,
+ } => {
+ let file_mtimes = db.get_file_mtimes(worktree_db_id);
+ sender.send(file_mtimes).ok();
+ }
+ }
+ }
+
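+    /// Requests embeddings for every document in the batch and forwards the populated
+    /// files to the database writer.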
+ async fn compute_embeddings_for_batch(
+ mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
+ embedding_provider: &Arc<dyn EmbeddingProvider>,
+ db_update_tx: &channel::Sender<DbOperation>,
+ ) {
+ let mut batch_documents = vec![];
+ for (_, documents, _, _, _) in embeddings_queue.iter() {
+ batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
+ }
+
+ if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
+ log::trace!(
+ "created {} embeddings for {} files",
+ embeddings.len(),
+ embeddings_queue.len(),
+ );
+
+ let mut i = 0;
+ let mut j = 0;
+
+ for embedding in embeddings.iter() {
+ while embeddings_queue[i].1.len() == j {
+ i += 1;
+ j = 0;
+ }
+
+ embeddings_queue[i].1[j].embedding = embedding.to_owned();
+ j += 1;
+ }
+
+ for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
+                // TODO: previously this asserted that every document received an
+                // embedding; re-add a non-panicking check for incomplete embeddings.
+
+ db_update_tx
+ .send(DbOperation::InsertFile {
+ worktree_id,
+ documents,
+ path,
+ mtime,
+ job_handle,
+ })
+ .await
+ .unwrap();
+ }
+ }
+ }
+
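+    /// Accumulates parsed files until the batch is large enough (or a flush is requested),
+    /// then sends the batch off to be embedded.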
+ fn enqueue_documents_to_embed(
+ job: EmbeddingJob,
+ queue_len: &mut usize,
+ embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
+ embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
+ ) {
+ let should_flush = match job {
+ EmbeddingJob::Enqueue {
+ documents,
+ worktree_id,
+ path,
+ mtime,
+ job_handle,
+ } => {
+                *queue_len += documents.len();
+ embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
+ *queue_len >= EMBEDDINGS_BATCH_SIZE
+ }
+ EmbeddingJob::Flush => true,
+ };
+
+ if should_flush {
+ embed_batch_tx
+ .try_send(mem::take(embeddings_queue))
+ .unwrap();
+ *queue_len = 0;
+ }
+ }
+
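+    /// Loads and parses a single pending file into embeddable documents, forwarding them
+    /// to the batching task (or directly to the database when nothing was produced).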
+ async fn parse_file(
+ fs: &Arc<dyn Fs>,
+ pending_file: PendingFile,
+ retriever: &mut CodeContextRetriever,
+ batch_files_tx: &channel::Sender<EmbeddingJob>,
+ parsing_files_rx: &channel::Receiver<PendingFile>,
+ db_update_tx: &channel::Sender<DbOperation>,
+ ) {
+ if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
+ if let Some(documents) = retriever
+ .parse_file(&pending_file.relative_path, &content, pending_file.language)
+ .log_err()
+ {
+ log::trace!(
+ "parsed path {:?}: {} documents",
+ pending_file.relative_path,
+ documents.len()
+ );
+
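+                // Record files that produce no documents directly, so they are not
+                // re-parsed on the next indexing pass.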
+                if documents.is_empty() {
+ db_update_tx
+ .send(DbOperation::InsertFile {
+ worktree_id: pending_file.worktree_db_id,
+ documents,
+ path: pending_file.relative_path,
+ mtime: pending_file.modified_time,
+ job_handle: pending_file.job_handle,
+ })
+ .await
+ .unwrap();
+ } else {
+ batch_files_tx
+ .try_send(EmbeddingJob::Enqueue {
+ worktree_id: pending_file.worktree_db_id,
+ path: pending_file.relative_path,
+ mtime: pending_file.modified_time,
+ job_handle: pending_file.job_handle,
+ documents,
+ })
+ .unwrap();
+ }
+ }
+ }
+
+        if parsing_files_rx.is_empty() {
+ batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+ }
+ }
+
fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
let (tx, rx) = oneshot::channel();
self.db_update_tx
@@ -457,11 +466,11 @@ impl SemanticIndex {
async move { rx.await? }
}
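+    /// Queues every file whose contents have changed since it was last indexed, returning
+    /// the number of files queued along with a watch channel reporting how many of them
+    /// are still outstanding.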
- fn index_project(
+ pub fn index_project(
&mut self,
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
- ) -> Task<Result<usize>> {
+ ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
let worktree_scans_complete = project
.read(cx)
.worktrees(cx)
@@ -483,7 +492,6 @@ impl SemanticIndex {
let language_registry = self.language_registry.clone();
let db_update_tx = self.db_update_tx.clone();
let parsing_files_tx = self.parsing_files_tx.clone();
- let next_job_id = self.next_job_id.clone();
cx.spawn(|this, mut cx| async move {
futures::future::join_all(worktree_scans_complete).await;
@@ -509,8 +517,8 @@ impl SemanticIndex {
);
}
- // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
- let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
+ let (job_count_tx, job_count_rx) = watch::channel_with(0);
+ let job_count_tx = Arc::new(Mutex::new(job_count_tx));
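+            // The watch channel broadcasts how many queued files are still waiting to be
+            // embedded, which the UI uses to display indexing progress.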
this.update(&mut cx, |this, _| {
this.projects.insert(
project.downgrade(),
@@ -519,7 +527,8 @@ impl SemanticIndex {
.iter()
.map(|(a, b)| (*a, *b))
.collect(),
- outstanding_jobs: outstanding_jobs.clone(),
+ outstanding_job_count_rx: job_count_rx.clone(),
+ outstanding_job_count_tx: job_count_tx.clone(),
},
);
});
@@ -527,7 +536,6 @@ impl SemanticIndex {
cx.background()
.spawn(async move {
let mut count = 0;
- let t0 = Instant::now();
for worktree in worktrees.into_iter() {
let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
for file in worktree.files(false, 0) {
@@ -552,14 +560,11 @@ impl SemanticIndex {
.map_or(false, |existing_mtime| existing_mtime == file.mtime);
if !already_stored {
- log::trace!("sending for parsing: {:?}", path_buf);
count += 1;
- let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
+ *job_count_tx.lock().borrow_mut() += 1;
let job_handle = JobHandle {
- id: job_id,
- set: Arc::downgrade(&outstanding_jobs),
+ tx: Arc::downgrade(&job_count_tx),
};
- outstanding_jobs.lock().insert(job_id);
parsing_files_tx
.try_send(PendingFile {
worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
@@ -582,27 +587,22 @@ impl SemanticIndex {
.unwrap();
}
}
- log::trace!(
- "parsing worktree completed in {:?}",
- t0.elapsed().as_millis()
- );
- Ok(count)
+ anyhow::Ok((count, job_count_rx))
})
.await
})
}
- pub fn remaining_files_to_index_for_project(
+ pub fn outstanding_job_count_rx(
&self,
project: &ModelHandle<Project>,
- ) -> Option<usize> {
+ ) -> Option<watch::Receiver<usize>> {
Some(
self.projects
.get(&project.downgrade())?
- .outstanding_jobs
- .lock()
- .len(),
+ .outstanding_job_count_rx
+ .clone(),
)
}
@@ -678,8 +678,9 @@ impl Entity for SemanticIndex {
impl Drop for JobHandle {
fn drop(&mut self) {
- if let Some(set) = self.set.upgrade() {
- set.lock().remove(&self.id);
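+        // Dropping the handle decrements the project's outstanding job count, driving
+        // indexing progress toward completion.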
+ if let Some(tx) = self.tx.upgrade() {
+ let mut tx = tx.lock();
+ *tx.borrow_mut() -= 1;
}
}
}