Cargo.lock 🔗
@@ -8689,6 +8689,7 @@ dependencies = [
"languages",
"log",
"open_ai",
+ "parking_lot",
"project",
"serde",
"serde_json",
Max Brunsfeld , Antonio Scandurra , Kyle , Marshall , and Marshall Bowers created
Release Notes:
- N/A
---------
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>
Cargo.lock | 1
crates/assistant2/examples/assistant_example.rs | 16
crates/assistant2/examples/chat_with_functions.rs | 24
crates/assistant2/examples/file_interactions.rs | 24
crates/assistant2/src/assistant2.rs | 258 +---------------
crates/assistant2/src/tools.rs | 85 ++++-
crates/assistant_tooling/src/registry.rs | 63 +--
crates/assistant_tooling/src/tool.rs | 6
crates/semantic_index/Cargo.toml | 1
crates/semantic_index/src/semantic_index.rs | 217 +++++++++----
10 files changed, 293 insertions(+), 402 deletions(-)
@@ -8689,6 +8689,7 @@ dependencies = [
"languages",
"log",
"open_ai",
+ "parking_lot",
"project",
"serde",
"serde_json",
@@ -87,16 +87,14 @@ fn main() {
let project_index = semantic_index.project_index(project.clone(), cx);
- let mut tool_registry = ToolRegistry::new();
- tool_registry
- .register(ProjectIndexTool::new(project_index.clone(), fs.clone()))
- .context("failed to register ProjectIndexTool")
- .log_err();
-
- let tool_registry = Arc::new(tool_registry);
-
cx.open_window(WindowOptions::default(), |cx| {
- cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
+ let mut tool_registry = ToolRegistry::new();
+ tool_registry
+ .register(ProjectIndexTool::new(project_index.clone(), fs.clone()), cx)
+ .context("failed to register ProjectIndexTool")
+ .log_err();
+
+ cx.new_view(|cx| Example::new(language_registry, Arc::new(tool_registry), cx))
});
cx.activate(true);
})
@@ -135,7 +135,7 @@ impl LanguageModelTool for RollDiceTool {
return Task::ready(Ok(DiceRoll { rolls }));
}
- fn new_view(
+ fn output_view(
_tool_call_id: String,
_input: Self::Input,
result: Result<Self::Output>,
@@ -194,20 +194,20 @@ fn main() {
cx.spawn(|cx| async move {
cx.update(|cx| {
- let mut tool_registry = ToolRegistry::new();
- tool_registry
- .register(RollDiceTool::new())
- .context("failed to register DummyTool")
- .log_err();
+ cx.open_window(WindowOptions::default(), |cx| {
+ let mut tool_registry = ToolRegistry::new();
+ tool_registry
+ .register(RollDiceTool::new(), cx)
+ .context("failed to register DummyTool")
+ .log_err();
- let tool_registry = Arc::new(tool_registry);
+ let tool_registry = Arc::new(tool_registry);
- println!("Tools registered");
- for definition in tool_registry.definitions() {
- println!("{}", definition);
- }
+ println!("Tools registered");
+ for definition in tool_registry.definitions() {
+ println!("{}", definition);
+ }
- cx.open_window(WindowOptions::default(), |cx| {
cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
});
cx.activate(true);
@@ -115,7 +115,7 @@ impl LanguageModelTool for FileBrowserTool {
})
}
- fn new_view(
+ fn output_view(
_tool_call_id: String,
_input: Self::Input,
result: Result<Self::Output>,
@@ -174,20 +174,20 @@ fn main() {
let fs = Arc::new(fs::RealFs::new(None));
let cwd = std::env::current_dir().expect("Failed to get current working directory");
- let mut tool_registry = ToolRegistry::new();
- tool_registry
- .register(FileBrowserTool::new(fs, cwd))
- .context("failed to register FileBrowserTool")
- .log_err();
+ cx.open_window(WindowOptions::default(), |cx| {
+ let mut tool_registry = ToolRegistry::new();
+ tool_registry
+ .register(FileBrowserTool::new(fs, cwd), cx)
+ .context("failed to register FileBrowserTool")
+ .log_err();
- let tool_registry = Arc::new(tool_registry);
+ let tool_registry = Arc::new(tool_registry);
- println!("Tools registered");
- for definition in tool_registry.definitions() {
- println!("{}", definition);
- }
+ println!("Tools registered");
+ for definition in tool_registry.definitions() {
+ println!("{}", definition);
+ }
- cx.open_window(WindowOptions::default(), |cx| {
cx.new_view(|cx| Example::new(language_registry, tool_registry, cx))
});
cx.activate(true);
@@ -8,22 +8,21 @@ use client::{proto, Client};
use completion_provider::*;
use editor::Editor;
use feature_flags::FeatureFlagAppExt as _;
-use futures::{channel::oneshot, future::join_all, Future, FutureExt, StreamExt};
+use futures::{future::join_all, StreamExt};
use gpui::{
list, prelude::*, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
- FocusableView, Global, ListAlignment, ListState, Model, Render, Task, View, WeakView,
+ FocusableView, Global, ListAlignment, ListState, Render, Task, View, WeakView,
};
use language::{language_settings::SoftWrap, LanguageRegistry};
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
-use project::Fs;
use rich_text::RichText;
-use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
+use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::Deserialize;
use settings::Settings;
-use std::{cmp, sync::Arc};
+use std::sync::Arc;
use theme::ThemeSettings;
use tools::ProjectIndexTool;
-use ui::{popover_menu, prelude::*, ButtonLike, CollapsibleContainer, Color, ContextMenu, Tooltip};
+use ui::{popover_menu, prelude::*, ButtonLike, Color, ContextMenu, Tooltip};
use util::{paths::EMBEDDINGS_DIR, ResultExt};
use workspace::{
dock::{DockPosition, Panel, PanelEvent},
@@ -110,10 +109,10 @@ impl AssistantPanel {
let mut tool_registry = ToolRegistry::new();
tool_registry
- .register(ProjectIndexTool::new(
- project_index.clone(),
- app_state.fs.clone(),
- ))
+ .register(
+ ProjectIndexTool::new(project_index.clone(), app_state.fs.clone()),
+ cx,
+ )
.context("failed to register ProjectIndexTool")
.log_err();
@@ -447,11 +446,7 @@ impl AssistantChat {
}
editor
});
- let message = ChatMessage::User(UserMessage {
- id,
- body,
- contexts: Vec::new(),
- });
+ let message = ChatMessage::User(UserMessage { id, body });
self.push_message(message, cx);
}
@@ -525,11 +520,7 @@ impl AssistantChat {
let is_last = ix == self.messages.len() - 1;
match &self.messages[ix] {
- ChatMessage::User(UserMessage {
- body,
- contexts: _contexts,
- ..
- }) => div()
+ ChatMessage::User(UserMessage { body, .. }) => div()
.when(!is_last, |element| element.mb_2())
.child(div().p_2().child(Label::new("You").color(Color::Default)))
.child(
@@ -539,7 +530,7 @@ impl AssistantChat {
.text_color(cx.theme().colors().editor_foreground)
.font(ThemeSettings::get_global(cx).buffer_font.clone())
.bg(cx.theme().colors().editor_background)
- .child(body.clone()), // .children(contexts.iter().map(|context| context.render(cx))),
+ .child(body.clone()),
)
.into_any(),
ChatMessage::Assistant(AssistantMessage {
@@ -588,11 +579,11 @@ impl AssistantChat {
for message in &self.messages {
match message {
- ChatMessage::User(UserMessage { body, contexts, .. }) => {
- // setup context for model
- contexts.iter().for_each(|context| {
- completion_messages.extend(context.completion_messages(cx))
- });
+ ChatMessage::User(UserMessage { body, .. }) => {
+ // When we re-introduce contexts like active file, we'll inject them here instead of relying on the model to request them
+ // contexts.iter().for_each(|context| {
+ // completion_messages.extend(context.completion_messages(cx))
+ // });
// Show user's message last so that the assistant is grounded in the user's request
completion_messages.push(CompletionMessage::User {
@@ -712,6 +703,12 @@ impl Render for AssistantChat {
.text_color(Color::Default.color(cx))
.child(self.render_model_dropdown(cx))
.child(list(self.list_state.clone()).flex_1())
+ .child(
+ h_flex()
+ .mt_2()
+ .gap_2()
+ .children(self.tool_registry.status_views().iter().cloned()),
+ )
}
}
@@ -743,7 +740,6 @@ impl ChatMessage {
struct UserMessage {
id: MessageId,
body: View<Editor>,
- contexts: Vec<AssistantContext>,
}
struct AssistantMessage {
@@ -752,211 +748,3 @@ struct AssistantMessage {
tool_calls: Vec<ToolFunctionCall>,
error: Option<SharedString>,
}
-
-// Since we're swapping out for direct query usage, we might not need to use this injected context
-// It will be useful though for when the user _definitely_ wants the model to see a specific file,
-// query, error, etc.
-#[allow(dead_code)]
-enum AssistantContext {
- Codebase(View<CodebaseContext>),
-}
-
-#[allow(dead_code)]
-struct CodebaseExcerpt {
- element_id: ElementId,
- path: SharedString,
- text: SharedString,
- score: f32,
- expanded: bool,
-}
-
-impl AssistantContext {
- #[allow(dead_code)]
- fn render(&self, _cx: &mut ViewContext<AssistantChat>) -> AnyElement {
- match self {
- AssistantContext::Codebase(context) => context.clone().into_any_element(),
- }
- }
-
- fn completion_messages(&self, cx: &WindowContext) -> Vec<CompletionMessage> {
- match self {
- AssistantContext::Codebase(context) => context.read(cx).completion_messages(),
- }
- }
-}
-
-enum CodebaseContext {
- Pending { _task: Task<()> },
- Done(Result<Vec<CodebaseExcerpt>>),
-}
-
-impl CodebaseContext {
- fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
- if let CodebaseContext::Done(Ok(excerpts)) = self {
- if let Some(excerpt) = excerpts
- .iter_mut()
- .find(|excerpt| excerpt.element_id == element_id)
- {
- excerpt.expanded = !excerpt.expanded;
- cx.notify();
- }
- }
- }
-}
-
-impl Render for CodebaseContext {
- fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- match self {
- CodebaseContext::Pending { .. } => div()
- .h_flex()
- .items_center()
- .gap_1()
- .child(Icon::new(IconName::Ai).color(Color::Muted).into_element())
- .child("Searching codebase..."),
- CodebaseContext::Done(Ok(excerpts)) => {
- div()
- .v_flex()
- .gap_2()
- .children(excerpts.iter().map(|excerpt| {
- let expanded = excerpt.expanded;
- let element_id = excerpt.element_id.clone();
-
- CollapsibleContainer::new(element_id.clone(), expanded)
- .start_slot(
- h_flex()
- .gap_1()
- .child(Icon::new(IconName::File).color(Color::Muted))
- .child(Label::new(excerpt.path.clone()).color(Color::Muted)),
- )
- .on_click(cx.listener(move |this, _, cx| {
- this.toggle_expanded(element_id.clone(), cx);
- }))
- .child(
- div()
- .p_2()
- .rounded_md()
- .bg(cx.theme().colors().editor_background)
- .child(
- excerpt.text.clone(), // todo!(): Show as an editor block
- ),
- )
- }))
- }
- CodebaseContext::Done(Err(error)) => div().child(error.to_string()),
- }
- }
-}
-
-impl CodebaseContext {
- #[allow(dead_code)]
- fn new(
- query: impl 'static + Future<Output = Result<String>>,
- populated: oneshot::Sender<bool>,
- project_index: Model<ProjectIndex>,
- fs: Arc<dyn Fs>,
- cx: &mut ViewContext<Self>,
- ) -> Self {
- let query = query.boxed_local();
- let _task = cx.spawn(|this, mut cx| async move {
- let result = async {
- let query = query.await?;
- let results = this
- .update(&mut cx, |_this, cx| {
- project_index.read(cx).search(&query, 16, cx)
- })?
- .await;
-
- let excerpts = results.into_iter().map(|result| {
- let abs_path = result
- .worktree
- .read_with(&cx, |worktree, _| worktree.abs_path().join(&result.path));
- let fs = fs.clone();
-
- async move {
- let path = result.path.clone();
- let text = fs.load(&abs_path?).await?;
- // todo!("what should we do with stale ranges?");
- let range = cmp::min(result.range.start, text.len())
- ..cmp::min(result.range.end, text.len());
-
- let text = SharedString::from(text[range].to_string());
-
- anyhow::Ok(CodebaseExcerpt {
- element_id: ElementId::Name(nanoid::nanoid!().into()),
- path: path.to_string_lossy().to_string().into(),
- text,
- score: result.score,
- expanded: false,
- })
- }
- });
-
- anyhow::Ok(
- futures::future::join_all(excerpts)
- .await
- .into_iter()
- .filter_map(|result| result.log_err())
- .collect(),
- )
- }
- .await;
-
- this.update(&mut cx, |this, cx| {
- this.populate(result, populated, cx);
- })
- .ok();
- });
-
- Self::Pending { _task }
- }
-
- #[allow(dead_code)]
- fn populate(
- &mut self,
- result: Result<Vec<CodebaseExcerpt>>,
- populated: oneshot::Sender<bool>,
- cx: &mut ViewContext<Self>,
- ) {
- let success = result.is_ok();
- *self = Self::Done(result);
- populated.send(success).ok();
- cx.notify();
- }
-
- fn completion_messages(&self) -> Vec<CompletionMessage> {
- // One system message for the whole batch of excerpts:
-
- // Semantic search results for user query:
- //
- // Excerpt from $path:
- // ~~~
- // `text`
- // ~~~
- //
- // Excerpt from $path:
-
- match self {
- CodebaseContext::Done(Ok(excerpts)) => {
- if excerpts.is_empty() {
- return Vec::new();
- }
-
- let mut body = "Semantic search results for user query:\n".to_string();
-
- for excerpt in excerpts {
- body.push_str("Excerpt from ");
- body.push_str(excerpt.path.as_ref());
- body.push_str(", score ");
- body.push_str(&excerpt.score.to_string());
- body.push_str(":\n");
- body.push_str("~~~\n");
- body.push_str(excerpt.text.as_ref());
- body.push_str("~~~\n");
- }
-
- vec![CompletionMessage::System { content: body }]
- }
- _ => vec![],
- }
- }
-}
@@ -1,9 +1,9 @@
use anyhow::Result;
use assistant_tooling::LanguageModelTool;
-use gpui::{prelude::*, AppContext, Model, Task};
+use gpui::{prelude::*, AnyView, AppContext, Model, Task};
use project::Fs;
use schemars::JsonSchema;
-use semantic_index::ProjectIndex;
+use semantic_index::{ProjectIndex, Status};
use serde::Deserialize;
use std::sync::Arc;
use ui::{
@@ -36,13 +36,14 @@ pub struct CodebaseQuery {
pub struct ProjectIndexView {
input: CodebaseQuery,
- output: Result<Vec<CodebaseExcerpt>>,
+ output: Result<ProjectIndexOutput>,
}
impl ProjectIndexView {
fn toggle_expanded(&mut self, element_id: ElementId, cx: &mut ViewContext<Self>) {
- if let Ok(excerpts) = &mut self.output {
- if let Some(excerpt) = excerpts
+ if let Ok(output) = &mut self.output {
+ if let Some(excerpt) = output
+ .excerpts
.iter_mut()
.find(|excerpt| excerpt.element_id == element_id)
{
@@ -59,11 +60,11 @@ impl Render for ProjectIndexView {
let result = &self.output;
- let excerpts = match result {
+ let output = match result {
Err(err) => {
return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
}
- Ok(excerpts) => excerpts,
+ Ok(output) => output,
};
div()
@@ -80,7 +81,7 @@ impl Render for ProjectIndexView {
.child(Label::new(query).color(Color::Muted)),
),
)
- .children(excerpts.iter().map(|excerpt| {
+ .children(output.excerpts.iter().map(|excerpt| {
let element_id = excerpt.element_id.clone();
let expanded = excerpt.expanded;
@@ -99,9 +100,7 @@ impl Render for ProjectIndexView {
.p_2()
.rounded_md()
.bg(cx.theme().colors().editor_background)
- .child(
- excerpt.text.clone(), // todo!(): Show as an editor block
- ),
+ .child(excerpt.text.clone()),
)
}))
}
@@ -112,8 +111,15 @@ pub struct ProjectIndexTool {
fs: Arc<dyn Fs>,
}
+pub struct ProjectIndexOutput {
+ excerpts: Vec<CodebaseExcerpt>,
+ status: Status,
+}
+
impl ProjectIndexTool {
pub fn new(project_index: Model<ProjectIndex>, fs: Arc<dyn Fs>) -> Self {
+ // Listen for project index status and update the ProjectIndexTool directly
+
// TODO: setup a better description based on the user's current codebase.
Self { project_index, fs }
}
@@ -121,7 +127,7 @@ impl ProjectIndexTool {
impl LanguageModelTool for ProjectIndexTool {
type Input = CodebaseQuery;
- type Output = Vec<CodebaseExcerpt>;
+ type Output = ProjectIndexOutput;
type View = ProjectIndexView;
fn name(&self) -> String {
@@ -135,6 +141,7 @@ impl LanguageModelTool for ProjectIndexTool {
fn execute(&self, query: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
let project_index = self.project_index.read(cx);
+ let status = project_index.status();
let results = project_index.search(
query.query.as_str(),
query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
@@ -180,11 +187,11 @@ impl LanguageModelTool for ProjectIndexTool {
.into_iter()
.filter_map(|result| result.log_err())
.collect();
- anyhow::Ok(excerpts)
+ anyhow::Ok(ProjectIndexOutput { excerpts, status })
})
}
- fn new_view(
+ fn output_view(
_tool_call_id: String,
input: Self::Input,
output: Result<Self::Output>,
@@ -193,16 +200,28 @@ impl LanguageModelTool for ProjectIndexTool {
cx.new_view(|_cx| ProjectIndexView { input, output })
}
+ fn status_view(&self, cx: &mut WindowContext) -> Option<AnyView> {
+ Some(
+ cx.new_view(|cx| ProjectIndexStatusView::new(self.project_index.clone(), cx))
+ .into(),
+ )
+ }
+
fn format(_input: &Self::Input, output: &Result<Self::Output>) -> String {
match &output {
- Ok(excerpts) => {
- if excerpts.len() == 0 {
- return "No results found".to_string();
+ Ok(output) => {
+ let mut body = "Semantic search results:\n".to_string();
+
+ if output.status != Status::Idle {
+ body.push_str("Still indexing. Results may be incomplete.\n");
}
- let mut body = "Semantic search results:\n".to_string();
+ if output.excerpts.is_empty() {
+ body.push_str("No results found");
+ return body;
+ }
- for excerpt in excerpts {
+ for excerpt in &output.excerpts {
body.push_str("Excerpt from ");
body.push_str(excerpt.path.as_ref());
body.push_str(", score ");
@@ -218,3 +237,31 @@ impl LanguageModelTool for ProjectIndexTool {
}
}
}
+
+struct ProjectIndexStatusView {
+ project_index: Model<ProjectIndex>,
+}
+
+impl ProjectIndexStatusView {
+ pub fn new(project_index: Model<ProjectIndex>, cx: &mut ViewContext<Self>) -> Self {
+ cx.subscribe(&project_index, |_this, _, _status: &Status, cx| {
+ cx.notify();
+ })
+ .detach();
+ Self { project_index }
+ }
+}
+
+impl Render for ProjectIndexStatusView {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let status = self.project_index.read(cx).status();
+
+ h_flex().gap_2().map(|element| match status {
+ Status::Idle => element.child(Label::new("Project index ready")),
+ Status::Loading => element.child(Label::new("Project index loading...")),
+ Status::Scanning { remaining_count } => element.child(Label::new(format!(
+ "Project index scanning: {remaining_count} remaining..."
+ ))),
+ })
+ }
+}
@@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
-use gpui::{Task, WindowContext};
+use gpui::{AnyView, Task, WindowContext};
use std::collections::HashMap;
use crate::tool::{
@@ -12,6 +12,7 @@ pub struct ToolRegistry {
Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
>,
definitions: Vec<ToolFunctionDefinition>,
+ status_views: Vec<AnyView>,
}
impl ToolRegistry {
@@ -19,6 +20,7 @@ impl ToolRegistry {
Self {
tools: HashMap::new(),
definitions: Vec::new(),
+ status_views: Vec::new(),
}
}
@@ -26,8 +28,17 @@ impl ToolRegistry {
&self.definitions
}
- pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
+ pub fn register<T: 'static + LanguageModelTool>(
+ &mut self,
+ tool: T,
+ cx: &mut WindowContext,
+ ) -> Result<()> {
self.definitions.push(tool.definition());
+
+ if let Some(tool_view) = tool.status_view(cx) {
+ self.status_views.push(tool_view);
+ }
+
let name = tool.name();
let previous = self.tools.insert(
name.clone(),
@@ -52,7 +63,7 @@ impl ToolRegistry {
cx.spawn(move |mut cx| async move {
let result: Result<T::Output> = result.await;
let for_model = T::format(&input, &result);
- let view = cx.update(|cx| T::new_view(id.clone(), input, result, cx))?;
+ let view = cx.update(|cx| T::output_view(id.clone(), input, result, cx))?;
Ok(ToolFunctionCall {
id,
@@ -100,6 +111,10 @@ impl ToolRegistry {
tool(tool_call, cx)
}
+
+ pub fn status_views(&self) -> &[AnyView] {
+ &self.status_views
+ }
}
#[cfg(test)]
@@ -165,7 +180,7 @@ mod test {
Task::ready(Ok(weather))
}
- fn new_view(
+ fn output_view(
_tool_call_id: String,
_input: Self::Input,
result: Result<Self::Output>,
@@ -182,46 +197,6 @@ mod test {
}
}
- #[gpui::test]
- async fn test_function_registry(cx: &mut TestAppContext) {
- cx.background_executor.run_until_parked();
-
- let mut registry = ToolRegistry::new();
-
- let tool = WeatherTool {
- current_weather: WeatherResult {
- location: "San Francisco".to_string(),
- temperature: 21.0,
- unit: "Celsius".to_string(),
- },
- };
-
- registry.register(tool).unwrap();
-
- // let _result = cx
- // .update(|cx| {
- // registry.call(
- // &ToolFunctionCall {
- // name: "get_current_weather".to_string(),
- // arguments: r#"{ "location": "San Francisco", "unit": "Celsius" }"#
- // .to_string(),
- // id: "test-123".to_string(),
- // result: None,
- // },
- // cx,
- // )
- // })
- // .await;
-
- // assert!(result.is_ok());
- // let result = result.unwrap();
-
- // let expected = r#"{"location":"San Francisco","temperature":21.0,"unit":"Celsius"}"#;
-
- // todo!(): Put this back in after the interface is stabilized
- // assert_eq!(result, expected);
- }
-
#[gpui::test]
async fn test_openai_weather_example(cx: &mut TestAppContext) {
cx.background_executor.run_until_parked();
@@ -95,10 +95,14 @@ pub trait LanguageModelTool {
fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
- fn new_view(
+ fn output_view(
tool_call_id: String,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> View<Self::View>;
+
+ fn status_view(&self, _cx: &mut WindowContext) -> Option<AnyView> {
+ None
+ }
}
@@ -30,6 +30,7 @@ language.workspace = true
log.workspace = true
heed.workspace = true
open_ai.workspace = true
+parking_lot.workspace = true
project.workspace = true
settings.workspace = true
serde.workspace = true
@@ -3,7 +3,7 @@ mod embedding;
use anyhow::{anyhow, Context as _, Result};
use chunking::{chunk_text, Chunk};
-use collections::{Bound, HashMap};
+use collections::{Bound, HashMap, HashSet};
pub use embedding::*;
use fs::Fs;
use futures::stream::StreamExt;
@@ -14,15 +14,17 @@ use gpui::{
};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
-use project::{Entry, Project, UpdatedEntriesSet, Worktree};
+use parking_lot::Mutex;
+use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
cmp::Ordering,
future::Future,
+ num::NonZeroUsize,
ops::Range,
path::{Path, PathBuf},
- sync::Arc,
+ sync::{Arc, Weak},
time::{Duration, SystemTime},
};
use util::ResultExt;
@@ -102,19 +104,16 @@ pub struct ProjectIndex {
worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
- pub last_status: Status,
+ last_status: Status,
+ status_tx: channel::Sender<()>,
embedding_provider: Arc<dyn EmbeddingProvider>,
+ _maintain_status: Task<()>,
_subscription: Subscription,
}
enum WorktreeIndexHandle {
- Loading {
- _task: Task<Result<()>>,
- },
- Loaded {
- index: Model<WorktreeIndex>,
- _subscription: Subscription,
- },
+ Loading { _task: Task<Result<()>> },
+ Loaded { index: Model<WorktreeIndex> },
}
impl ProjectIndex {
@@ -126,20 +125,36 @@ impl ProjectIndex {
) -> Self {
let language_registry = project.read(cx).languages().clone();
let fs = project.read(cx).fs().clone();
+ let (status_tx, mut status_rx) = channel::unbounded();
let mut this = ProjectIndex {
db_connection,
project: project.downgrade(),
worktree_indices: HashMap::default(),
language_registry,
fs,
+ status_tx,
last_status: Status::Idle,
embedding_provider,
_subscription: cx.subscribe(&project, Self::handle_project_event),
+ _maintain_status: cx.spawn(|this, mut cx| async move {
+ while status_rx.next().await.is_some() {
+ if this
+ .update(&mut cx, |this, cx| this.update_status(cx))
+ .is_err()
+ {
+ break;
+ }
+ }
+ }),
};
this.update_worktree_indices(cx);
this
}
+ pub fn status(&self) -> Status {
+ self.last_status
+ }
+
fn handle_project_event(
&mut self,
_: Model<Project>,
@@ -180,19 +195,18 @@ impl ProjectIndex {
self.db_connection.clone(),
self.language_registry.clone(),
self.fs.clone(),
+ self.status_tx.clone(),
self.embedding_provider.clone(),
cx,
);
let load_worktree = cx.spawn(|this, mut cx| async move {
- if let Some(index) = worktree_index.await.log_err() {
- this.update(&mut cx, |this, cx| {
+ if let Some(worktree_index) = worktree_index.await.log_err() {
+ this.update(&mut cx, |this, _| {
this.worktree_indices.insert(
worktree_id,
WorktreeIndexHandle::Loaded {
- _subscription: cx
- .observe(&index, |this, _, cx| this.update_status(cx)),
- index,
+ index: worktree_index,
},
);
})?;
@@ -215,22 +229,29 @@ impl ProjectIndex {
}
fn update_status(&mut self, cx: &mut ModelContext<Self>) {
- let mut status = Status::Idle;
- for index in self.worktree_indices.values() {
+ let mut indexing_count = 0;
+ let mut any_loading = false;
+
+ for index in self.worktree_indices.values_mut() {
match index {
WorktreeIndexHandle::Loading { .. } => {
- status = Status::Scanning;
+ any_loading = true;
break;
}
WorktreeIndexHandle::Loaded { index, .. } => {
- if index.read(cx).status == Status::Scanning {
- status = Status::Scanning;
- break;
- }
+ indexing_count += index.read(cx).entry_ids_being_indexed.len();
}
}
}
+ let status = if any_loading {
+ Status::Loading
+ } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
+ Status::Scanning { remaining_count }
+ } else {
+ Status::Idle
+ };
+
if status != self.last_status {
self.last_status = status;
cx.emit(status);
@@ -263,6 +284,17 @@ impl ProjectIndex {
results
})
}
+
+ #[cfg(test)]
+ pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
+ let mut result = 0;
+ for worktree_index in self.worktree_indices.values() {
+ if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
+ result += index.read(cx).path_count()?;
+ }
+ }
+ Ok(result)
+ }
}
pub struct SearchResult {
@@ -275,7 +307,8 @@ pub struct SearchResult {
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
Idle,
- Scanning,
+ Loading,
+ Scanning { remaining_count: NonZeroUsize },
}
impl EventEmitter<Status> for ProjectIndex {}
@@ -287,7 +320,7 @@ struct WorktreeIndex {
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
- status: Status,
+ entry_ids_being_indexed: Arc<IndexingEntrySet>,
_index_entries: Task<Result<()>>,
_subscription: Subscription,
}
@@ -298,6 +331,7 @@ impl WorktreeIndex {
db_connection: heed::Env,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
+ status_tx: channel::Sender<()>,
embedding_provider: Arc<dyn EmbeddingProvider>,
cx: &mut AppContext,
) -> Task<Result<Model<Self>>> {
@@ -321,6 +355,7 @@ impl WorktreeIndex {
worktree,
db_connection,
db,
+ status_tx,
language_registry,
fs,
embedding_provider,
@@ -330,10 +365,12 @@ impl WorktreeIndex {
})
}
+ #[allow(clippy::too_many_arguments)]
fn new(
worktree: Model<Worktree>,
db_connection: heed::Env,
db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+ status: channel::Sender<()>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
embedding_provider: Arc<dyn EmbeddingProvider>,
@@ -353,7 +390,7 @@ impl WorktreeIndex {
language_registry,
fs,
embedding_provider,
- status: Status::Idle,
+ entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
_index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
_subscription,
}
@@ -364,28 +401,14 @@ impl WorktreeIndex {
updated_entries: channel::Receiver<UpdatedEntriesSet>,
mut cx: AsyncAppContext,
) -> Result<()> {
- let index = this.update(&mut cx, |this, cx| {
- cx.notify();
- this.status = Status::Scanning;
- this.index_entries_changed_on_disk(cx)
- })?;
+ let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
index.await.log_err();
- this.update(&mut cx, |this, cx| {
- this.status = Status::Idle;
- cx.notify();
- })?;
while let Ok(updated_entries) = updated_entries.recv().await {
let index = this.update(&mut cx, |this, cx| {
- cx.notify();
- this.status = Status::Scanning;
this.index_updated_entries(updated_entries, cx)
})?;
index.await.log_err();
- this.update(&mut cx, |this, cx| {
- this.status = Status::Idle;
- cx.notify();
- })?;
}
Ok(())
@@ -426,6 +449,7 @@ impl WorktreeIndex {
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
let db_connection = self.db_connection.clone();
let db = self.db;
+ let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
let txn = db_connection
.read_txn()
@@ -476,7 +500,8 @@ impl WorktreeIndex {
}
if entry.mtime != saved_mtime {
- updated_entries_tx.send(entry.clone()).await?;
+ let handle = entries_being_indexed.insert(&entry);
+ updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
@@ -505,6 +530,7 @@ impl WorktreeIndex {
) -> ScanEntries {
let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+ let entries_being_indexed = self.entry_ids_being_indexed.clone();
let task = cx.background_executor().spawn(async move {
for (path, entry_id, status) in updated_entries.iter() {
match status {
@@ -513,7 +539,8 @@ impl WorktreeIndex {
| project::PathChange::AddedOrUpdated => {
if let Some(entry) = worktree.entry_for_id(*entry_id) {
if entry.is_file() {
- updated_entries_tx.send(entry.clone()).await?;
+ let handle = entries_being_indexed.insert(&entry);
+ updated_entries_tx.send((entry.clone(), handle)).await?;
}
}
}
@@ -542,7 +569,7 @@ impl WorktreeIndex {
fn chunk_files(
&self,
worktree_abs_path: Arc<Path>,
- entries: channel::Receiver<Entry>,
+ entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
cx: &AppContext,
) -> ChunkFiles {
let language_registry = self.language_registry.clone();
@@ -553,7 +580,7 @@ impl WorktreeIndex {
.scoped(|cx| {
for _ in 0..cx.num_cpus() {
cx.spawn(async {
- while let Ok(entry) = entries.recv().await {
+ while let Ok((entry, handle)) = entries.recv().await {
let entry_abs_path = worktree_abs_path.join(&entry.path);
let Some(text) = fs
.load(&entry_abs_path)
@@ -572,8 +599,8 @@ impl WorktreeIndex {
let grammar =
language.as_ref().and_then(|language| language.grammar());
let chunked_file = ChunkedFile {
- worktree_root: worktree_abs_path.clone(),
chunks: chunk_text(&text, grammar),
+ handle,
entry,
text,
};
@@ -622,7 +649,11 @@ impl WorktreeIndex {
let mut embeddings = Vec::new();
for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
- embeddings.extend(embedding_provider.embed(embedding_batch).await?);
+ if let Some(batch_embeddings) =
+ embedding_provider.embed(embedding_batch).await.log_err()
+ {
+ embeddings.extend_from_slice(&batch_embeddings);
+ }
}
let mut embeddings = embeddings.into_iter();
@@ -643,7 +674,9 @@ impl WorktreeIndex {
chunks: embedded_chunks,
};
- embedded_files_tx.send(embedded_file).await?;
+ embedded_files_tx
+ .send((embedded_file, chunked_file.handle))
+ .await?;
}
}
Ok(())
@@ -658,7 +691,7 @@ impl WorktreeIndex {
fn persist_embeddings(
&self,
mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
- embedded_files: channel::Receiver<EmbeddedFile>,
+ embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
cx: &AppContext,
) -> Task<Result<()>> {
let db_connection = self.db_connection.clone();
@@ -676,12 +709,15 @@ impl WorktreeIndex {
let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
while let Some(embedded_files) = embedded_files.next().await {
let mut txn = db_connection.write_txn()?;
- for file in embedded_files {
+ for (file, _) in &embedded_files {
log::debug!("saving embedding for file {:?}", file.path);
let key = db_key_for_path(&file.path);
- db.put(&mut txn, &key, &file)?;
+ db.put(&mut txn, &key, file)?;
}
txn.commit()?;
+ eprintln!("committed {:?}", embedded_files.len());
+
+ drop(embedded_files);
log::debug!("committed");
}
@@ -789,10 +825,19 @@ impl WorktreeIndex {
Ok(search_results)
})
}
+
+ #[cfg(test)]
+ fn path_count(&self) -> Result<u64> {
+ let txn = self
+ .db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ Ok(self.db.len(&txn)?)
+ }
}
struct ScanEntries {
- updated_entries: channel::Receiver<Entry>,
+ updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
task: Task<Result<()>>,
}
@@ -803,15 +848,14 @@ struct ChunkFiles {
}
struct ChunkedFile {
- #[allow(dead_code)]
- pub worktree_root: Arc<Path>,
pub entry: Entry,
+ pub handle: IndexingEntryHandle,
pub text: String,
pub chunks: Vec<Chunk>,
}
struct EmbedFiles {
- files: channel::Receiver<EmbeddedFile>,
+ files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
task: Task<Result<()>>,
}
@@ -828,6 +872,47 @@ struct EmbeddedChunk {
embedding: Embedding,
}
+struct IndexingEntrySet {
+ entry_ids: Mutex<HashSet<ProjectEntryId>>,
+ tx: channel::Sender<()>,
+}
+
+struct IndexingEntryHandle {
+ entry_id: ProjectEntryId,
+ set: Weak<IndexingEntrySet>,
+}
+
+impl IndexingEntrySet {
+ fn new(tx: channel::Sender<()>) -> Self {
+ Self {
+ entry_ids: Default::default(),
+ tx,
+ }
+ }
+
+ fn insert(self: &Arc<Self>, entry: &project::Entry) -> IndexingEntryHandle {
+ self.entry_ids.lock().insert(entry.id);
+ self.tx.send_blocking(()).ok();
+ IndexingEntryHandle {
+ entry_id: entry.id,
+ set: Arc::downgrade(self),
+ }
+ }
+
+ pub fn len(&self) -> usize {
+ self.entry_ids.lock().len()
+ }
+}
+
+impl Drop for IndexingEntryHandle {
+ fn drop(&mut self) {
+ if let Some(set) = self.set.upgrade() {
+ set.tx.send_blocking(()).ok();
+ set.entry_ids.lock().remove(&self.entry_id);
+ }
+ }
+}
+
fn db_key_for_path(path: &Arc<Path>) -> String {
path.to_string_lossy().replace('/', "\0")
}
@@ -835,10 +920,7 @@ fn db_key_for_path(path: &Arc<Path>) -> String {
#[cfg(test)]
mod tests {
use super::*;
-
- use futures::channel::oneshot;
use futures::{future::BoxFuture, FutureExt};
-
use gpui::{Global, TestAppContext};
use language::language_settings::AllLanguageSettings;
use project::Project;
@@ -922,18 +1004,13 @@ mod tests {
let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));
- let (tx, rx) = oneshot::channel();
- let mut tx = Some(tx);
- let subscription = cx.update(|cx| {
- cx.subscribe(&project_index, move |_, event, _| {
- if let Some(tx) = tx.take() {
- _ = tx.send(*event);
- }
- })
- });
-
- rx.await.expect("no event emitted");
- drop(subscription);
+ while project_index
+ .read_with(cx, |index, cx| index.path_count(cx))
+ .unwrap()
+ == 0
+ {
+ project_index.next_event(cx).await;
+ }
let results = cx
.update(|cx| {