Semantic index progress (#11071)

Max Brunsfeld, Antonio Scandurra, Kyle, Marshall, and Marshall Bowers created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

Cargo.lock                                        |   1 
crates/assistant2/examples/assistant_example.rs   |  16 
crates/assistant2/examples/chat_with_functions.rs |  24 
crates/assistant2/examples/file_interactions.rs   |  24 
crates/assistant2/src/assistant2.rs               | 258 +---------------
crates/assistant2/src/tools.rs                    |  85 ++++-
crates/assistant_tooling/src/registry.rs          |  63 +--
crates/assistant_tooling/src/tool.rs              |   6 
crates/semantic_index/Cargo.toml                  |   1 
crates/semantic_index/src/semantic_index.rs       | 217 +++++++++----
10 files changed, 293 insertions(+), 402 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -8689,6 +8689,7 @@ dependencies = [
  "languages",
  "log",
  "open_ai",
+ "parking_lot",
  "project",
  "serde",
  "serde_json",

crates/assistant2/examples/assistant_example.rs 🔗

@@ -87,16 +87,14 @@ fn main() {
 
                 let project_index = semantic_index.project_index(project.clone(), cx);
 
-                let mut tool_registry = ToolRegistry::new();
-                tool_registry
-                    .register(ProjectIndexTool::new(project_index.clone(), fs.clone()))
-                    .context("failed to register ProjectIndexTool")
-                    .log_err();
-
-                let tool_registry = Arc::new(tool_registry);
-
                 cx.open_window(WindowOptions::default(), |cx| {
-                    cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
+                    let mut tool_registry = ToolRegistry::new();
+                    tool_registry
+                        .register(ProjectIndexTool::new(project_index.clone(), fs.clone()), cx)
+                        .context("failed to register ProjectIndexTool")
+                        .log_err();
+
+                    cx.new_view(|cx| Example::new(language_registry, Arc::new(tool_registry), cx))
                 });
                 cx.activate(true);
             })

crates/assistant2/examples/chat_with_functions.rs 🔗

@@ -135,7 +135,7 @@ impl LanguageModelTool for RollDiceTool {
         return Task::ready(Ok(DiceRoll { rolls }));
     }
 
-    fn new_view(
+    fn output_view(
         _tool_call_id: String,
         _input: Self::Input,
         result: Result<Self::Output>,
@@ -194,20 +194,20 @@ fn main() {
 
         cx.spawn(|cx| async move {
             cx.update(|cx| {
-                let mut tool_registry = ToolRegistry::new();
-                tool_registry
-                    .register(RollDiceTool::new())
-                    .context("failed to register DummyTool")
-                    .log_err();
+                cx.open_window(WindowOptions::default(), |cx| {
+                    let mut tool_registry = ToolRegistry::new();
+                    tool_registry
+                        .register(RollDiceTool::new(), cx)
+                        .context("failed to register DummyTool")
+                        .log_err();
 
-                let tool_registry = Arc::new(tool_registry);
+                    let tool_registry = Arc::new(tool_registry);
 
-                println!("Tools registered");
-                for definition in tool_registry.definitions() {
-                    println!("{}", definition);
-                }
+                    println!("Tools registered");
+                    for definition in tool_registry.definitions() {
+                        println!("{}", definition);
+                    }
 
-                cx.open_window(WindowOptions::default(), |cx| {
                     cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
                 });
                 cx.activate(true);

crates/assistant2/examples/file_interactions.rs 🔗

@@ -115,7 +115,7 @@ impl LanguageModelTool for FileBrowserTool {
         })
     }
 
-    fn new_view(
+    fn output_view(
         _tool_call_id: String,
         _input: Self::Input,
         result: Result<Self::Output>,
@@ -174,20 +174,20 @@ fn main() {
                 let fs = Arc::new(fs::RealFs::new(None));
                 let cwd = std::env::current_dir().expect("Failed to get current working directory");
 
-                let mut tool_registry = ToolRegistry::new();
-                tool_registry
-                    .register(FileBrowserTool::new(fs, cwd))
-                    .context("failed to register FileBrowserTool")
-                    .log_err();
+                cx.open_window(WindowOptions::default(), |cx| {
+                    let mut tool_registry = ToolRegistry::new();
+                    tool_registry
+                        .register(FileBrowserTool::new(fs, cwd), cx)
+                        .context("failed to register FileBrowserTool")
+                        .log_err();
 
-                let tool_registry = Arc::new(tool_registry);
+                    let tool_registry = Arc::new(tool_registry);
 
-                println!("Tools registered");
-                for definition in tool_registry.definitions() {
-                    println!("{}", definition);
-                }
+                    println!("Tools registered");
+                    for definition in tool_registry.definitions() {
+                        println!("{}", definition);
+                    }
 
-                cx.open_window(WindowOptions::default(), |cx| {
                     cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
                 });
                 cx.activate(true);

crates/assistant2/src/assistant2.rs 🔗

@@ -8,22 +8,21 @@ use client::{proto, Client};
 use completion_provider::*;
 use editor::Editor;
 use feature_flags::FeatureFlagAppExt as _;
-use futures::{channel::oneshot, future::join_all, Future, FutureExt, StreamExt};
+use futures::{future::join_all, StreamExt};
 use gpui::{
     list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
-    FocusableView, Global, ListAlignment, ListState, Model, Render, Task, View, WeakView,
+    FocusableView, Global, ListAlignment, ListState, Render, Task, View, WeakView,
 };
 use language::{language_settings::SoftWrap, LanguageRegistry};
 use open_ai::{FunctionContent, ToolCall, ToolCallContent};
-use project::Fs;
 use rich_text::RichText;
-use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
+use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
 use serde::Deserialize;
 use settings::Settings;
-use std::{cmp, sync::Arc};
+use std::sync::Arc;
 use theme::ThemeSettings;
 use tools::ProjectIndexTool;
-use ui::{popover_menu, prelude::*, ButtonLike, CollapsibleContainer, Color, ContextMenu, Tooltip};
+use ui::{popover_menu, prelude::*, ButtonLike, Color, ContextMenu, Tooltip};
 use util::{paths::EMBEDDINGS_DIR, ResultExt};
 use workspace::{
     dock::{DockPosition, Panel, PanelEvent},
@@ -110,10 +109,10 @@ impl AssistantPanel {
 
                 let mut tool_registry = ToolRegistry::new();
                 tool_registry
-                    .register(ProjectIndexTool::new(
-                        project_index.clone(),
-                        app_state.fs.clone(),
-                    ))
+                    .register(
+                        ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
+                        cx,
+                    )
                     .context("failed to register ProjectIndexTool")
                     .log_err();
 
@@ -447,11 +446,7 @@ impl AssistantChat {
             }
             editor
         });
-        let message = ChatMessage::User(UserMessage {
-            id,
-            body,
-            contexts: Vec::new(),
-        });
+        let message = ChatMessage::User(UserMessage { id, body });
         self.push_message(message, cx);
     }
 
@@ -525,11 +520,7 @@ impl AssistantChat {
         let is_last = ix == self.messages.len() - 1;
 
         match &self.messages[ix] {
-            ChatMessage::User(UserMessage {
-                body,
-                contexts: _contexts,
-                ..
-            }) => div()
+            ChatMessage::User(UserMessage { body, .. }) => div()
                 .when(!is_last, |element| element.mb_2())
                 .child(div().p_2().child(Label::new("You").color(Color::Default)))
                 .child(
@@ -539,7 +530,7 @@ impl AssistantChat {
                         .text_color(cx.theme().colors().editor_foreground)
                         .font(ThemeSettings::get_global(cx).buffer_font.clone())
                         .bg(cx.theme().colors().editor_background)
-                        .child(body.clone()), // .children(contexts.iter().map(|context| context.render(cx))),
+                        .child(body.clone()),
                 )
                 .into_any(),
             ChatMessage::Assistant(AssistantMessage {
@@ -588,11 +579,11 @@ impl AssistantChat {
 
         for message in &self.messages {
             match message {
-                ChatMessage::User(UserMessage { body, contexts, .. }) => {
-                    // setup context for model
-                    contexts.iter().for_each(|context| {
-                        completion_messages.extend(context.completion_messages(cx))
-                    });
+                ChatMessage::User(UserMessage { body, .. }) => {
+                    // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
+                    // contexts.iter().for_each(|context| {
+                    //     completion_messages.extend(context.completion_messages(cx))
+                    // });
 
                     // Show user's message last so that the assistant is grounded in the user's request
                     completion_messages.push(CompletionMessage::User {
@@ -712,6 +703,12 @@ impl Render for AssistantChat {
             .text_color(Color::Default.color(cx))
             .child(self.render_model_dropdown(cx))
             .child(list(self.list_state.clone()).flex_1())
+            .child(
+                h_flex()
+                    .mt_2()
+                    .gap_2()
+                    .children(self.tool_registry.status_views().iter().cloned()),
+            )
     }
 }
 
@@ -743,7 +740,6 @@ impl ChatMessage {
 struct UserMessage {
     id: MessageId,
     body: View<Editor>,
-    contexts: Vec<AssistantContext>,
 }
 
 struct AssistantMessage {
@@ -752,211 +748,3 @@ struct AssistantMessage {
     tool_calls: Vec<ToolFunctionCall>,
     error: Option<SharedString>,
 }
-
-// Since we're swapping out for direct query usage, we might not need to use this injected context
-// It will be useful though for when the user _definitely_ wants the model to see a specific file,
-// query, error, etc.
-#[allow(dead_code)]
-enum AssistantContext {
-    Codebase(View<CodebaseContext>),
-}
-
-#[allow(dead_code)]
-struct CodebaseExcerpt {
-    element_id: ElementId,
-    path: SharedString,
-    text: SharedString,
-    score: f32,
-    expanded: bool,
-}
-
-impl AssistantContext {
-    #[allow(dead_code)]
-    fn render(&self, _cx: &mut ViewContext<AssistantChat>) -> AnyElement {
-        match self {
-            AssistantContext::Codebase(context) => context.clone().into_any_element(),
-        }
-    }
-
-    fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
-        match self {
-            AssistantContext::Codebase(context) => context.read(cx).completion_messages(),
-        }
-    }
-}
-
-enum CodebaseContext {
-    Pending { _task: Task<()> },
-    Done(Result<Vec<CodebaseExcerpt>>),
-}
-
-impl CodebaseContext {
-    fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
-        if let CodebaseContext::Done(Ok(excerpts)) = self {
-            if let Some(excerpt) = excerpts
-                .iter_mut()
-                .find(|excerpt| excerpt.element_id == element_id)
-            {
-                excerpt.expanded = !excerpt.expanded;
-                cx.notify();
-            }
-        }
-    }
-}
-
-impl Render for CodebaseContext {
-    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
-        match self {
-            CodebaseContext::Pending { .. } => div()
-                .h_flex()
-                .items_center()
-                .gap_1()
-                .child(Icon::new(IconName::Ai).color(Color::Muted).into_element())
-                .child("Searching codebase..."),
-            CodebaseContext::Done(Ok(excerpts)) => {
-                div()
-                    .v_flex()
-                    .gap_2()
-                    .children(excerpts.iter().map(|excerpt| {
-                        let expanded = excerpt.expanded;
-                        let element_id = excerpt.element_id.clone();
-
-                        CollapsibleContainer::new(element_id.clone(), expanded)
-                            .start_slot(
-                                h_flex()
-                                    .gap_1()
-                                    .child(Icon::new(IconName::File).color(Color::Muted))
-                                    .child(Label::new(excerpt.path.clone()).color(Color::Muted)),
-                            )
-                            .on_click(cx.listener(move |this, _, cx| {
-                                this.toggle_expanded(element_id.clone(), cx);
-                            }))
-                            .child(
-                                div()
-                                    .p_2()
-                                    .rounded_md()
-                                    .bg(cx.theme().colors().editor_background)
-                                    .child(
-                                        excerpt.text.clone(), // todo!(): Show as an editor block
-                                    ),
-                            )
-                    }))
-            }
-            CodebaseContext::Done(Err(error)) => div().child(error.to_string()),
-        }
-    }
-}
-
-impl CodebaseContext {
-    #[allow(dead_code)]
-    fn new(
-        query: impl 'static + Future<Output = Result<String>>,
-        populated: oneshot::Sender<bool>,
-        project_index: Model<ProjectIndex>,
-        fs: Arc<dyn Fs>,
-        cx: &mut ViewContext<Self>,
-    ) -> Self {
-        let query = query.boxed_local();
-        let _task = cx.spawn(|this, mut cx| async move {
-            let result = async {
-                let query = query.await?;
-                let results = this
-                    .update(&mut cx, |_this, cx| {
-                        project_index.read(cx).search(&query, 16, cx)
-                    })?
-                    .await;
-
-                let excerpts = results.into_iter().map(|result| {
-                    let abs_path = result
-                        .worktree
-                        .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
-                    let fs = fs.clone();
-
-                    async move {
-                        let path = result.path.clone();
-                        let text = fs.load(&abs_path?).await?;
-                        // todo!("what should we do with stale ranges?");
-                        let range = cmp::min(result.range.start, text.len())
-                            ..cmp::min(result.range.end, text.len());
-
-                        let text = SharedString::from(text[range].to_string());
-
-                        anyhow::Ok(CodebaseExcerpt {
-                            element_id: ElementId::Name(nanoid::nanoid!().into()),
-                            path: path.to_string_lossy().to_string().into(),
-                            text,
-                            score: result.score,
-                            expanded: false,
-                        })
-                    }
-                });
-
-                anyhow::Ok(
-                    futures::future::join_all(excerpts)
-                        .await
-                        .into_iter()
-                        .filter_map(|result| result.log_err())
-                        .collect(),
-                )
-            }
-            .await;
-
-            this.update(&mut cx, |this, cx| {
-                this.populate(result, populated, cx);
-            })
-            .ok();
-        });
-
-        Self::Pending { _task }
-    }
-
-    #[allow(dead_code)]
-    fn populate(
-        &mut self,
-        result: Result<Vec<CodebaseExcerpt>>,
-        populated: oneshot::Sender<bool>,
-        cx: &mut ViewContext<Self>,
-    ) {
-        let success = result.is_ok();
-        *self = Self::Done(result);
-        populated.send(success).ok();
-        cx.notify();
-    }
-
-    fn completion_messages(&self) -> Vec<CompletionMessage> {
-        // One system message for the whole batch of excerpts:
-
-        // Semantic search results for user query:
-        //
-        // Excerpt from $path:
-        // ~~~
-        // `text`
-        // ~~~
-        //
-        // Excerpt from $path:
-
-        match self {
-            CodebaseContext::Done(Ok(excerpts)) => {
-                if excerpts.is_empty() {
-                    return Vec::new();
-                }
-
-                let mut body = "Semantic search results for user query:\n".to_string();
-
-                for excerpt in excerpts {
-                    body.push_str("Excerpt from ");
-                    body.push_str(excerpt.path.as_ref());
-                    body.push_str(", score ");
-                    body.push_str(&excerpt.score.to_string());
-                    body.push_str(":\n");
-                    body.push_str("~~~\n");
-                    body.push_str(excerpt.text.as_ref());
-                    body.push_str("~~~\n");
-                }
-
-                vec![CompletionMessage::System { content: body }]
-            }
-            _ => vec![],
-        }
-    }
-}

crates/assistant2/src/tools.rs 🔗

@@ -1,9 +1,9 @@
 use anyhow::Result;
 use assistant_tooling::LanguageModelTool;
-use gpui::{prelude::*, AppContext, Model, Task};
+use gpui::{prelude::*, AnyView, AppContext, Model, Task};
 use project::Fs;
 use schemars::JsonSchema;
-use semantic_index::ProjectIndex;
+use semantic_index::{ProjectIndex, Status};
 use serde::Deserialize;
 use std::sync::Arc;
 use ui::{
@@ -36,13 +36,14 @@ pub struct CodebaseQuery {
 
 pub struct ProjectIndexView {
     input: CodebaseQuery,
-    output: Result<Vec<CodebaseExcerpt>>,
+    output: Result<ProjectIndexOutput>,
 }
 
 impl ProjectIndexView {
     fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
-        if let Ok(excerpts) = &mut self.output {
-            if let Some(excerpt) = excerpts
+        if let Ok(output) = &mut self.output {
+            if let Some(excerpt) = output
+                .excerpts
                 .iter_mut()
                 .find(|excerpt| excerpt.element_id == element_id)
             {
@@ -59,11 +60,11 @@ impl Render for ProjectIndexView {
 
         let result = &self.output;
 
-        let excerpts = match result {
+        let output = match result {
             Err(err) => {
                 return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
             }
-            Ok(excerpts) => excerpts,
+            Ok(output) => output,
         };
 
         div()
@@ -80,7 +81,7 @@ impl Render for ProjectIndexView {
                             .child(Label::new(query).color(Color::Muted)),
                     ),
             )
-            .children(excerpts.iter().map(|excerpt| {
+            .children(output.excerpts.iter().map(|excerpt| {
                 let element_id = excerpt.element_id.clone();
                 let expanded = excerpt.expanded;
 
@@ -99,9 +100,7 @@ impl Render for ProjectIndexView {
                             .p_2()
                             .rounded_md()
                             .bg(cx.theme().colors().editor_background)
-                            .child(
-                                excerpt.text.clone(), // todo!(): Show as an editor block
-                            ),
+                            .child(excerpt.text.clone()),
                     )
             }))
     }
@@ -112,8 +111,15 @@ pub struct ProjectIndexTool {
     fs: Arc<dyn Fs>,
 }
 
+pub struct ProjectIndexOutput {
+    excerpts: Vec<CodebaseExcerpt>,
+    status: Status,
+}
+
 impl ProjectIndexTool {
     pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
+        // Listen for project index status and update the ProjectIndexTool directly
+
         // TODO: setup a better description based on the user's current codebase.
         Self { project_index, fs }
     }
@@ -121,7 +127,7 @@ impl ProjectIndexTool {
 
 impl LanguageModelTool for ProjectIndexTool {
     type Input = CodebaseQuery;
-    type Output = Vec<CodebaseExcerpt>;
+    type Output = ProjectIndexOutput;
     type View = ProjectIndexView;
 
     fn name(&self) -> String {
@@ -135,6 +141,7 @@ impl LanguageModelTool for ProjectIndexTool {
     fn execute(&self, query: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
         let project_index = self.project_index.read(cx);
 
+        let status = project_index.status();
         let results = project_index.search(
             query.query.as_str(),
             query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
@@ -180,11 +187,11 @@ impl LanguageModelTool for ProjectIndexTool {
                 .into_iter()
                 .filter_map(|result| result.log_err())
                 .collect();
-            anyhow::Ok(excerpts)
+            anyhow::Ok(ProjectIndexOutput { excerpts, status })
         })
     }
 
-    fn new_view(
+    fn output_view(
         _tool_call_id: String,
         input: Self::Input,
         output: Result<Self::Output>,
@@ -193,16 +200,28 @@ impl LanguageModelTool for ProjectIndexTool {
         cx.new_view(|_cx| ProjectIndexView { input, output })
     }
 
+    fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
+        Some(
+            cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
+                .into(),
+        )
+    }
+
     fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
         match &output {
-            Ok(excerpts) => {
-                if excerpts.len() == 0 {
-                    return "No results found".to_string();
+            Ok(output) => {
+                let mut body = "Semantic search results:\n".to_string();
+
+                if output.status != Status::Idle {
+                    body.push_str("Still indexing. Results may be incomplete.\n");
                 }
 
-                let mut body = "Semantic search results:\n".to_string();
+                if output.excerpts.is_empty() {
+                    body.push_str("No results found");
+                    return body;
+                }
 
-                for excerpt in excerpts {
+                for excerpt in &output.excerpts {
                     body.push_str("Excerpt from ");
                     body.push_str(excerpt.path.as_ref());
                     body.push_str(", score ");
@@ -218,3 +237,31 @@ impl LanguageModelTool for ProjectIndexTool {
         }
     }
 }
+
+struct ProjectIndexStatusView {
+    project_index: Model<ProjectIndex>,
+}
+
+impl ProjectIndexStatusView {
+    pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
+        cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
+            cx.notify();
+        })
+        .detach();
+        Self { project_index }
+    }
+}
+
+impl Render for ProjectIndexStatusView {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let status = self.project_index.read(cx).status();
+
+        h_flex().gap_2().map(|element| match status {
+            Status::Idle => element.child(Label::new("Project index ready")),
+            Status::Loading => element.child(Label::new("Project index loading...")),
+            Status::Scanning { remaining_count } => element.child(Label::new(format!(
+                "Project index scanning: {remaining_count} remaining..."
+            ))),
+        })
+    }
+}

crates/assistant_tooling/src/registry.rs 🔗

@@ -1,5 +1,5 @@
 use anyhow::{anyhow, Result};
-use gpui::{Task, WindowContext};
+use gpui::{AnyView, Task, WindowContext};
 use std::collections::HashMap;
 
 use crate::tool::{
@@ -12,6 +12,7 @@ pub struct ToolRegistry {
         Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
     >,
     definitions: Vec<ToolFunctionDefinition>,
+    status_views: Vec<AnyView>,
 }
 
 impl ToolRegistry {
@@ -19,6 +20,7 @@ impl ToolRegistry {
         Self {
             tools: HashMap::new(),
             definitions: Vec::new(),
+            status_views: Vec::new(),
         }
     }
 
@@ -26,8 +28,17 @@ impl ToolRegistry {
         &self.definitions
     }
 
-    pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
+    pub fn register<T: 'static + LanguageModelTool>(
+        &mut self,
+        tool: T,
+        cx: &mut WindowContext,
+    ) -> Result<()> {
         self.definitions.push(tool.definition());
+
+        if let Some(tool_view) = tool.status_view(cx) {
+            self.status_views.push(tool_view);
+        }
+
         let name = tool.name();
         let previous = self.tools.insert(
             name.clone(),
@@ -52,7 +63,7 @@ impl ToolRegistry {
                     cx.spawn(move |mut cx| async move {
                         let result: Result<T::Output> = result.await;
                         let for_model = T::format(&input, &result);
-                        let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?;
+                        let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
 
                         Ok(ToolFunctionCall {
                             id,
@@ -100,6 +111,10 @@ impl ToolRegistry {
 
         tool(tool_call, cx)
     }
+
+    pub fn status_views(&self) -> &[AnyView] {
+        &self.status_views
+    }
 }
 
 #[cfg(test)]
@@ -165,7 +180,7 @@ mod test {
             Task::ready(Ok(weather))
         }
 
-        fn new_view(
+        fn output_view(
             _tool_call_id: String,
             _input: Self::Input,
             result: Result<Self::Output>,
@@ -182,46 +197,6 @@ mod test {
         }
     }
 
-    #[gpui::test]
-    async fn test_function_registry(cx: &mut TestAppContext) {
-        cx.background_executor.run_until_parked();
-
-        let mut registry = ToolRegistry::new();
-
-        let tool = WeatherTool {
-            current_weather: WeatherResult {
-                location: "San Francisco".to_string(),
-                temperature: 21.0,
-                unit: "Celsius".to_string(),
-            },
-        };
-
-        registry.register(tool).unwrap();
-
-        // let _result = cx
-        //     .update(|cx| {
-        //         registry.call(
-        //             &ToolFunctionCall {
-        //                 name: "get_current_weather".to_string(),
-        //                 arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
-        //                     .to_string(),
-        //                 id: "test-123".to_string(),
-        //                 result: None,
-        //             },
-        //             cx,
-        //         )
-        //     })
-        //     .await;
-
-        // assert!(result.is_ok());
-        // let result = result.unwrap();
-
-        // let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
-
-        // todo!(): Put this back in after the interface is stabilized
-        // assert_eq!(result, expected);
-    }
-
     #[gpui::test]
     async fn test_openai_weather_example(cx: &mut TestAppContext) {
         cx.background_executor.run_until_parked();

crates/assistant_tooling/src/tool.rs 🔗

@@ -95,10 +95,14 @@ pub trait LanguageModelTool {
 
     fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
 
-    fn new_view(
+    fn output_view(
         tool_call_id: String,
         input: Self::Input,
         output: Result<Self::Output>,
         cx: &mut WindowContext,
     ) -> View<Self::View>;
+
+    fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
+        None
+    }
 }

crates/semantic_index/Cargo.toml 🔗

@@ -30,6 +30,7 @@ language.workspace = true
 log.workspace = true
 heed.workspace = true
 open_ai.workspace = true
+parking_lot.workspace = true
 project.workspace = true
 settings.workspace = true
 serde.workspace = true

crates/semantic_index/src/semantic_index.rs 🔗

@@ -3,7 +3,7 @@ mod embedding;
 
 use anyhow::{anyhow, Context as _, Result};
 use chunking::{chunk_text, Chunk};
-use collections::{Bound, HashMap};
+use collections::{Bound, HashMap, HashSet};
 pub use embedding::*;
 use fs::Fs;
 use futures::stream::StreamExt;
@@ -14,15 +14,17 @@ use gpui::{
 };
 use heed::types::{SerdeBincode, Str};
 use language::LanguageRegistry;
-use project::{Entry, Project, UpdatedEntriesSet, Worktree};
+use parking_lot::Mutex;
+use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
 use serde::{Deserialize, Serialize};
 use smol::channel;
 use std::{
     cmp::Ordering,
     future::Future,
+    num::NonZeroUsize,
     ops::Range,
     path::{Path, PathBuf},
-    sync::Arc,
+    sync::{Arc, Weak},
     time::{Duration, SystemTime},
 };
 use util::ResultExt;
@@ -102,19 +104,16 @@ pub struct ProjectIndex {
     worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
     language_registry: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
-    pub last_status: Status,
+    last_status: Status,
+    status_tx: channel::Sender<()>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
+    _maintain_status: Task<()>,
     _subscription: Subscription,
 }
 
 enum WorktreeIndexHandle {
-    Loading {
-        _task: Task<Result<()>>,
-    },
-    Loaded {
-        index: Model<WorktreeIndex>,
-        _subscription: Subscription,
-    },
+    Loading { _task: Task<Result<()>> },
+    Loaded { index: Model<WorktreeIndex> },
 }
 
 impl ProjectIndex {
@@ -126,20 +125,36 @@ impl ProjectIndex {
     ) -> Self {
         let language_registry = project.read(cx).languages().clone();
         let fs = project.read(cx).fs().clone();
+        let (status_tx, mut status_rx) = channel::unbounded();
         let mut this = ProjectIndex {
             db_connection,
             project: project.downgrade(),
             worktree_indices: HashMap::default(),
             language_registry,
             fs,
+            status_tx,
             last_status: Status::Idle,
             embedding_provider,
             _subscription: cx.subscribe(&project, Self::handle_project_event),
+            _maintain_status: cx.spawn(|this, mut cx| async move {
+                while status_rx.next().await.is_some() {
+                    if this
+                        .update(&mut cx, |this, cx| this.update_status(cx))
+                        .is_err()
+                    {
+                        break;
+                    }
+                }
+            }),
         };
         this.update_worktree_indices(cx);
         this
     }
 
+    pub fn status(&self) -> Status {
+        self.last_status
+    }
+
     fn handle_project_event(
         &mut self,
         _: Model<Project>,
@@ -180,19 +195,18 @@ impl ProjectIndex {
                     self.db_connection.clone(),
                     self.language_registry.clone(),
                     self.fs.clone(),
+                    self.status_tx.clone(),
                     self.embedding_provider.clone(),
                     cx,
                 );
 
                 let load_worktree = cx.spawn(|this, mut cx| async move {
-                    if let Some(index) = worktree_index.await.log_err() {
-                        this.update(&mut cx, |this, cx| {
+                    if let Some(worktree_index) = worktree_index.await.log_err() {
+                        this.update(&mut cx, |this, _| {
                             this.worktree_indices.insert(
                                 worktree_id,
                                 WorktreeIndexHandle::Loaded {
-                                    _subscription: cx
-                                        .observe(&index, |this, _, cx| this.update_status(cx)),
-                                    index,
+                                    index: worktree_index,
                                 },
                             );
                         })?;
@@ -215,22 +229,29 @@ impl ProjectIndex {
     }
 
     fn update_status(&mut self, cx: &mut ModelContext<Self>) {
-        let mut status = Status::Idle;
-        for index in self.worktree_indices.values() {
+        let mut indexing_count = 0;
+        let mut any_loading = false;
+
+        for index in self.worktree_indices.values_mut() {
             match index {
                 WorktreeIndexHandle::Loading { .. } => {
-                    status = Status::Scanning;
+                    any_loading = true;
                     break;
                 }
                 WorktreeIndexHandle::Loaded { index, .. } => {
-                    if index.read(cx).status == Status::Scanning {
-                        status = Status::Scanning;
-                        break;
-                    }
+                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                 }
             }
         }
 
+        let status = if any_loading {
+            Status::Loading
+        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
+            Status::Scanning { remaining_count }
+        } else {
+            Status::Idle
+        };
+
         if status != self.last_status {
             self.last_status = status;
             cx.emit(status);
@@ -263,6 +284,17 @@ impl ProjectIndex {
             results
         })
     }
+
+    /// Test-only helper: sums `WorktreeIndex::path_count` over every loaded worktree index.
+    ///
+    /// Worktrees whose index is still loading contribute nothing. The first
+    /// worktree that fails aborts the sum and its error is returned.
+    #[cfg(test)]
+    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
+        self.worktree_indices
+            .values()
+            .filter_map(|handle| match handle {
+                WorktreeIndexHandle::Loaded { index, .. } => Some(index),
+                WorktreeIndexHandle::Loading { .. } => None,
+            })
+            .try_fold(0u64, |total, index| -> Result<u64> {
+                Ok(total + index.read(cx).path_count()?)
+            })
+    }
 }
 
 pub struct SearchResult {
@@ -275,7 +307,8 @@ pub struct SearchResult {
 /// Indexing progress emitted by `ProjectIndex` whenever it changes.
 ///
 /// - `Idle`: no worktree index is loading and no entries are being indexed.
 /// - `Loading`: at least one worktree index has not finished loading.
 /// - `Scanning`: every known index is loaded; `remaining_count` is the total number
 ///   of entries currently being indexed across them.
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum Status {
     Idle,
-    Scanning,
+    Loading,
+    Scanning { remaining_count: NonZeroUsize },
 }
 
 impl EventEmitter<Status> for ProjectIndex {}
@@ -287,7 +320,7 @@ struct WorktreeIndex {
     language_registry: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
-    status: Status,
+    entry_ids_being_indexed: Arc<IndexingEntrySet>,
     _index_entries: Task<Result<()>>,
     _subscription: Subscription,
 }
@@ -298,6 +331,7 @@ impl WorktreeIndex {
         db_connection: heed::Env,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
+        status_tx: channel::Sender<()>,
         embedding_provider: Arc<dyn EmbeddingProvider>,
         cx: &mut AppContext,
     ) -> Task<Result<Model<Self>>> {
@@ -321,6 +355,7 @@ impl WorktreeIndex {
                     worktree,
                     db_connection,
                     db,
+                    status_tx,
                     language_registry,
                     fs,
                     embedding_provider,
@@ -330,10 +365,12 @@ impl WorktreeIndex {
         })
     }
 
+    #[allow(clippy::too_many_arguments)]
     fn new(
         worktree: Model<Worktree>,
         db_connection: heed::Env,
         db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+        status: channel::Sender<()>,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
         embedding_provider: Arc<dyn EmbeddingProvider>,
@@ -353,7 +390,7 @@ impl WorktreeIndex {
             language_registry,
             fs,
             embedding_provider,
-            status: Status::Idle,
+            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
             _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
             _subscription,
         }
@@ -364,28 +401,14 @@ impl WorktreeIndex {
         updated_entries: channel::Receiver<UpdatedEntriesSet>,
         mut cx: AsyncAppContext,
     ) -> Result<()> {
-        let index = this.update(&mut cx, |this, cx| {
-            cx.notify();
-            this.status = Status::Scanning;
-            this.index_entries_changed_on_disk(cx)
-        })?;
+        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
         index.await.log_err();
-        this.update(&mut cx, |this, cx| {
-            this.status = Status::Idle;
-            cx.notify();
-        })?;
 
         while let Ok(updated_entries) = updated_entries.recv().await {
             let index = this.update(&mut cx, |this, cx| {
-                cx.notify();
-                this.status = Status::Scanning;
                 this.index_updated_entries(updated_entries, cx)
             })?;
             index.await.log_err();
-            this.update(&mut cx, |this, cx| {
-                this.status = Status::Idle;
-                cx.notify();
-            })?;
         }
 
         Ok(())
@@ -426,6 +449,7 @@ impl WorktreeIndex {
         let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
         let db_connection = self.db_connection.clone();
         let db = self.db;
+        let entries_being_indexed = self.entry_ids_being_indexed.clone();
         let task = cx.background_executor().spawn(async move {
             let txn = db_connection
                 .read_txn()
@@ -476,7 +500,8 @@ impl WorktreeIndex {
                 }
 
                 if entry.mtime != saved_mtime {
-                    updated_entries_tx.send(entry.clone()).await?;
+                    let handle = entries_being_indexed.insert(&entry);
+                    updated_entries_tx.send((entry.clone(), handle)).await?;
                 }
             }
 
@@ -505,6 +530,7 @@ impl WorktreeIndex {
     ) -> ScanEntries {
         let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
         let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+        let entries_being_indexed = self.entry_ids_being_indexed.clone();
         let task = cx.background_executor().spawn(async move {
             for (path, entry_id, status) in updated_entries.iter() {
                 match status {
@@ -513,7 +539,8 @@ impl WorktreeIndex {
                     | project::PathChange::AddedOrUpdated => {
                         if let Some(entry) = worktree.entry_for_id(*entry_id) {
                             if entry.is_file() {
-                                updated_entries_tx.send(entry.clone()).await?;
+                                let handle = entries_being_indexed.insert(&entry);
+                                updated_entries_tx.send((entry.clone(), handle)).await?;
                             }
                         }
                     }
@@ -542,7 +569,7 @@ impl WorktreeIndex {
     fn chunk_files(
         &self,
         worktree_abs_path: Arc<Path>,
-        entries: channel::Receiver<Entry>,
+        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> ChunkFiles {
         let language_registry = self.language_registry.clone();
@@ -553,7 +580,7 @@ impl WorktreeIndex {
                 .scoped(|cx| {
                     for _ in 0..cx.num_cpus() {
                         cx.spawn(async {
-                            while let Ok(entry) = entries.recv().await {
+                            while let Ok((entry, handle)) = entries.recv().await {
                                 let entry_abs_path = worktree_abs_path.join(&entry.path);
                                 let Some(text) = fs
                                     .load(&entry_abs_path)
@@ -572,8 +599,8 @@ impl WorktreeIndex {
                                 let grammar =
                                     language.as_ref().and_then(|language| language.grammar());
                                 let chunked_file = ChunkedFile {
-                                    worktree_root: worktree_abs_path.clone(),
                                     chunks: chunk_text(&text, grammar),
+                                    handle,
                                     entry,
                                     text,
                                 };
@@ -622,7 +649,11 @@ impl WorktreeIndex {
 
                 let mut embeddings = Vec::new();
                 for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
-                    embeddings.extend(embedding_provider.embed(embedding_batch).await?);
+                    if let Some(batch_embeddings) =
+                        embedding_provider.embed(embedding_batch).await.log_err()
+                    {
+                        embeddings.extend_from_slice(&batch_embeddings);
+                    }
                 }
 
                 let mut embeddings = embeddings.into_iter();
@@ -643,7 +674,9 @@ impl WorktreeIndex {
                         chunks: embedded_chunks,
                     };
 
-                    embedded_files_tx.send(embedded_file).await?;
+                    embedded_files_tx
+                        .send((embedded_file, chunked_file.handle))
+                        .await?;
                 }
             }
             Ok(())
@@ -658,7 +691,7 @@ impl WorktreeIndex {
     fn persist_embeddings(
         &self,
         mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
-        embedded_files: channel::Receiver<EmbeddedFile>,
+        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
         cx: &AppContext,
     ) -> Task<Result<()>> {
         let db_connection = self.db_connection.clone();
@@ -676,12 +709,15 @@ impl WorktreeIndex {
             let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
             while let Some(embedded_files) = embedded_files.next().await {
                 let mut txn = db_connection.write_txn()?;
-                for file in embedded_files {
+                for (file, _) in &embedded_files {
                     log::debug!("saving embedding for file {:?}", file.path);
                     let key = db_key_for_path(&file.path);
-                    db.put(&mut txn, &key, &file)?;
+                    db.put(&mut txn, &key, file)?;
                 }
                 txn.commit()?;
+                eprintln!("committed {:?}", embedded_files.len());
+
+                drop(embedded_files);
                 log::debug!("committed");
             }
 
@@ -789,10 +825,19 @@ impl WorktreeIndex {
             Ok(search_results)
         })
     }
+
+    /// Test-only helper: number of records currently stored in this index's database.
+    #[cfg(test)]
+    fn path_count(&self) -> Result<u64> {
+        let read_txn = self
+            .db_connection
+            .read_txn()
+            .context("failed to create read transaction")?;
+        let stored = self.db.len(&read_txn)?;
+        Ok(stored)
+    }
 }
 
 /// Result of a scan pass: two channels plus the background task that feeds them.
 ///
 /// `updated_entries` yields each file entry to (re)index, paired with the handle that
 /// keeps it counted as in progress. `deleted_entry_ranges` yields key ranges as bound
 /// pairs (presumably database ranges to delete; confirm against the consumer).
 struct ScanEntries {
-    updated_entries: channel::Receiver<Entry>,
+    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
     deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
     task: Task<Result<()>>,
 }
@@ -803,15 +848,14 @@ struct ChunkFiles {
 }
 
 /// A file whose text has been loaded and split into chunks by `chunk_files`.
 ///
 /// `handle` travels with the file through the pipeline so the entry stays counted
 /// as in progress until the handle is dropped.
 struct ChunkedFile {
-    #[allow(dead_code)]
-    pub worktree_root: Arc<Path>,
     pub entry: Entry,
+    pub handle: IndexingEntryHandle,
     pub text: String,
     pub chunks: Vec<Chunk>,
 }
 
 /// Output of the embedding stage: a receiver of embedded files, each still paired
 /// with its in-progress handle, plus a task (presumably the one producing them).
 struct EmbedFiles {
-    files: channel::Receiver<EmbeddedFile>,
+    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
     task: Task<Result<()>>,
 }
 
@@ -828,6 +872,47 @@ struct EmbeddedChunk {
     embedding: Embedding,
 }
 
+/// Tracks which project entries are currently moving through the indexing pipeline.
+///
+/// Every change to the set pushes a `()` onto `tx`, so the receiving side can
+/// recompute its status from `len()`.
+struct IndexingEntrySet {
+    entry_ids: Mutex<HashSet<ProjectEntryId>>,
+    tx: channel::Sender<()>,
+}
+
+/// Guard for one in-progress entry; dropping it removes the entry id from the set.
+///
+/// Only a `Weak` reference is held, so outstanding handles never keep the set alive.
+///
+/// NOTE(review): ids live in a `HashSet`, so if the same entry is inserted twice the
+/// first handle to drop removes the id while the second is still in flight.
+struct IndexingEntryHandle {
+    entry_id: ProjectEntryId,
+    set: Weak<IndexingEntrySet>,
+}
+
+impl IndexingEntrySet {
+    /// Creates an empty set that reports changes on `tx`.
+    fn new(tx: channel::Sender<()>) -> Self {
+        Self {
+            entry_ids: Default::default(),
+            tx,
+        }
+    }
+
+    /// Marks `entry` as being indexed and returns the handle that un-marks it on drop.
+    fn insert(self: &Arc<Self>, entry: &project::Entry) -> IndexingEntryHandle {
+        // Mutate first, notify second, so the receiver always observes the new size.
+        self.entry_ids.lock().insert(entry.id);
+        // A send error means the receiver is gone; there is nobody left to notify.
+        self.tx.send_blocking(()).ok();
+        IndexingEntryHandle {
+            entry_id: entry.id,
+            set: Arc::downgrade(self),
+        }
+    }
+
+    /// Number of entries currently being indexed.
+    pub fn len(&self) -> usize {
+        self.entry_ids.lock().len()
+    }
+}
+
+impl Drop for IndexingEntryHandle {
+    fn drop(&mut self) {
+        if let Some(set) = self.set.upgrade() {
+            // Remove before notifying. If the notification went out first, a receiver
+            // on another thread could read `len()` while this id is still present and
+            // report a stale count, with no later notification to correct it.
+            set.entry_ids.lock().remove(&self.entry_id);
+            set.tx.send_blocking(()).ok();
+        }
+    }
+}
+
 /// Builds the database key for `path`: its lossy UTF-8 form with every `/` replaced by NUL.
 fn db_key_for_path(path: &Arc<Path>) -> String {
     let lossy = path.to_string_lossy();
     lossy.replace('/', "\0")
 }
@@ -835,10 +920,7 @@ fn db_key_for_path(path: &Arc<Path>) -> String {
 #[cfg(test)]
 mod tests {
     use super::*;
-
-    use futures::channel::oneshot;
     use futures::{future::BoxFuture, FutureExt};
-
     use gpui::{Global, TestAppContext};
     use language::language_settings::AllLanguageSettings;
     use project::Project;
@@ -922,18 +1004,13 @@ mod tests {
 
         let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));
 
-        let (tx, rx) = oneshot::channel();
-        let mut tx = Some(tx);
-        let subscription = cx.update(|cx| {
-            cx.subscribe(&project_index, move |_, event, _| {
-                if let Some(tx) = tx.take() {
-                    _ = tx.send(*event);
-                }
-            })
-        });
-
-        rx.await.expect("no event emitted");
-        drop(subscription);
+        while project_index
+            .read_with(cx, |index, cx| index.path_count(cx))
+            .unwrap()
+            == 0
+        {
+            project_index.next_event(cx).await;
+        }
 
         let results = cx
             .update(|cx| {