diff --git a/Cargo.lock b/Cargo.lock
index 87557afcb1b868cf9321bc0a4746e92687bb456d..6d41fbe96fac878f496e93461c180e1c184216d6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5342,6 +5342,32 @@ dependencies = [
  "zlog",
 ]
 
+[[package]]
+name = "edit_prediction_context2"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "env_logger 0.11.8",
+ "futures 0.3.31",
+ "gpui",
+ "indoc",
+ "language",
+ "log",
+ "lsp",
+ "parking_lot",
+ "pretty_assertions",
+ "project",
+ "serde",
+ "serde_json",
+ "settings",
+ "smallvec",
+ "text",
+ "tree-sitter",
+ "util",
+ "zlog",
+]
+
 [[package]]
 name = "editor"
 version = "0.1.0"
@@ -21693,6 +21719,7 @@ dependencies = [
  "db",
  "edit_prediction",
  "edit_prediction_context",
+ "edit_prediction_context2",
  "editor",
  "feature_flags",
  "fs",
@@ -21742,7 +21769,6 @@ dependencies = [
  "clap",
  "client",
  "cloud_llm_client",
- "cloud_zeta2_prompt",
  "collections",
  "edit_prediction_context",
  "editor",
diff --git a/Cargo.toml b/Cargo.toml
index 59b9a53d4a60b28582625fb90b64b934079cdc40..62a44dbf35fefbf02a1b570146b0bf24cea6dcd8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,6 +56,7 @@ members = [
     "crates/edit_prediction",
     "crates/edit_prediction_button",
     "crates/edit_prediction_context",
+    "crates/edit_prediction_context2",
     "crates/zeta2_tools",
     "crates/editor",
     "crates/eval",
@@ -316,6 +317,7 @@ image_viewer = { path = "crates/image_viewer" }
 edit_prediction = { path = "crates/edit_prediction" }
 edit_prediction_button = { path = "crates/edit_prediction_button" }
 edit_prediction_context = { path = "crates/edit_prediction_context" }
+edit_prediction_context2 = { path = "crates/edit_prediction_context2" }
 zeta2_tools = { path = "crates/zeta2_tools" }
 inspector_ui = { path = "crates/inspector_ui" }
 install_cli = { path = "crates/install_cli" }
diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs
index 8ce8441859b7cc747a2b566dedd913e58259969d..8b234497376aefdc972681c877a1122f3f9cee17 100644
--- a/crates/edit_prediction_button/src/edit_prediction_button.rs
+++ b/crates/edit_prediction_button/src/edit_prediction_button.rs
@@ -1105,9 +1105,33 @@ impl EditPredictionButton {
                     .separator();
             }
 
-            let menu = self.build_language_settings_menu(menu, window, cx);
-            let menu = self.add_provider_switching_section(menu, provider, cx);
+            menu = self.build_language_settings_menu(menu, window, cx);
+
+            if cx.has_flag::() {
+                let settings = all_language_settings(None, cx);
+                let context_retrieval = settings.edit_predictions.use_context;
+                menu = menu.separator().header("Context Retrieval").item(
+                    ContextMenuEntry::new("Enable Context Retrieval")
+                        .toggleable(IconPosition::Start, context_retrieval)
+                        .action(workspace::ToggleEditPrediction.boxed_clone())
+                        .handler({
+                            let fs = self.fs.clone();
+                            move |_, cx| {
+                                update_settings_file(fs.clone(), cx, move |settings, _| {
+                                    settings
+                                        .project
+                                        .all_languages
+                                        .features
+                                        .get_or_insert_default()
+                                        .experimental_edit_prediction_context_retrieval =
+                                        Some(!context_retrieval)
+                                });
+                            }
+                        }),
+                );
+            }
 
+            menu = self.add_provider_switching_section(menu, provider, cx);
             menu
         })
     }
diff --git a/crates/edit_prediction_context2/Cargo.toml b/crates/edit_prediction_context2/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..597884b44821e24a930c8730225be4c6bf1c90f6
--- /dev/null
+++ b/crates/edit_prediction_context2/Cargo.toml
@@ -0,0 +1,42 @@
+[package]
+name = "edit_prediction_context2"
+version = "0.1.0"
+edition.workspace = true
+publish.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/edit_prediction_context2.rs"
+
+[dependencies]
+parking_lot.workspace = true
+anyhow.workspace = true
+collections.workspace = true
+futures.workspace = true
+gpui.workspace = true
+language.workspace = true
+lsp.workspace = true
+project.workspace = true
+log.workspace = true
+serde.workspace = true
+smallvec.workspace = true
+tree-sitter.workspace = true
+util.workspace = true
+
+[dev-dependencies]
+env_logger.workspace = true
+indoc.workspace = true
+futures.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+language = { workspace = true, features = ["test-support"] }
+lsp = { workspace = true, features = ["test-support"] }
+pretty_assertions.workspace = true
+project = { workspace = true, features = ["test-support"] }
+serde_json.workspace = true
+settings = { workspace = true, features = ["test-support"] }
+text = { workspace = true, features = ["test-support"] }
+util = { workspace = true, features = ["test-support"] }
+zlog.workspace = true
diff --git a/crates/edit_prediction_context2/LICENSE-GPL b/crates/edit_prediction_context2/LICENSE-GPL
new file mode 120000
index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4
--- /dev/null
+++ b/crates/edit_prediction_context2/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/edit_prediction_context2/src/assemble_excerpts.rs b/crates/edit_prediction_context2/src/assemble_excerpts.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b3b8d4f8bc480053a1e9ab9d498d5350039ed609
--- /dev/null
+++ b/crates/edit_prediction_context2/src/assemble_excerpts.rs
@@ -0,0 +1,324 @@
+use crate::RelatedExcerpt;
+use language::{BufferSnapshot, OffsetRangeExt as _, Point};
+use std::ops::Range;
+
+#[cfg(not(test))]
+const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
+#[cfg(test)]
+const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 24;
+
+pub fn assemble_excerpts(
+    buffer: &BufferSnapshot,
+    mut input_ranges: Vec<Range<Point>>,
+) -> Vec<RelatedExcerpt> {
+    merge_ranges(&mut input_ranges);
+
+    let mut outline_ranges = Vec::new();
+    let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
+    let mut outline_ix = 0;
+    for input_range in &mut input_ranges {
+        *input_range = clip_range_to_lines(input_range, false, buffer);
+
+        while let Some(outline_item) = outline_items.get(outline_ix) {
+            let item_range = clip_range_to_lines(&outline_item.range, false, buffer);
+
+            if item_range.start > input_range.start {
+                break;
+            }
+
+            if item_range.end > input_range.start {
+                let body_range = outline_item
+                    .body_range(buffer)
+                    .map(|body| clip_range_to_lines(&body, true, buffer))
+                    .filter(|body_range| {
+                        body_range.to_offset(buffer).len() > MAX_OUTLINE_ITEM_BODY_SIZE
+                    });
+
+                add_outline_item(
+                    item_range.clone(),
+                    body_range.clone(),
+                    buffer,
+                    &mut outline_ranges,
+                );
+
+                if let Some(body_range) = body_range
+                    && input_range.start < body_range.start
+                {
+                    let mut child_outline_ix = outline_ix + 1;
+                    while let Some(next_outline_item) = outline_items.get(child_outline_ix) {
+                        if next_outline_item.range.end > body_range.end {
+                            break;
+                        }
+                        if next_outline_item.depth == outline_item.depth + 1 {
+                            let next_item_range =
+                                clip_range_to_lines(&next_outline_item.range, false, buffer);
+
+                            add_outline_item(
+                                next_item_range,
+                                next_outline_item
+                                    .body_range(buffer)
+                                    .map(|body| clip_range_to_lines(&body, true, buffer)),
+                                buffer,
+                                &mut outline_ranges,
+                            );
+                        }
+                        child_outline_ix += 1;
+                    }
+                }
+            }
+
+            outline_ix += 1;
+        }
+    }
+
+    input_ranges.extend_from_slice(&outline_ranges);
+    merge_ranges(&mut input_ranges);
+
+    input_ranges
+        .into_iter()
+        .map(|range| {
+            let offset_range = range.to_offset(buffer);
+            RelatedExcerpt {
+                point_range: range,
+                anchor_range: buffer.anchor_before(offset_range.start)
+                    ..buffer.anchor_after(offset_range.end),
+                text: buffer.as_rope().slice(offset_range),
+            }
+        })
+        .collect()
+}
+
+fn clip_range_to_lines(
+    range: &Range<Point>,
+    inward: bool,
+    buffer: &BufferSnapshot,
+) -> Range<Point> {
+    let mut range = range.clone();
+    if inward {
+        if range.start.column > 0 {
+            range.start.column = buffer.line_len(range.start.row);
+        }
+        range.end.column = 0;
+    } else {
+        range.start.column = 0;
+        if range.end.column > 0 {
+            range.end.column = buffer.line_len(range.end.row);
+        }
+    }
+    range
+}
+
+fn add_outline_item(
+    mut item_range: Range<Point>,
+    body_range: Option<Range<Point>>,
+    buffer: &BufferSnapshot,
+    outline_ranges: &mut Vec<Range<Point>>,
+) {
+    if let Some(mut body_range) = body_range {
+        if body_range.start.column > 0 {
+            body_range.start.column = buffer.line_len(body_range.start.row);
+        }
+        body_range.end.column = 0;
+
+        let head_range = item_range.start..body_range.start;
+        if head_range.start < head_range.end {
+            outline_ranges.push(head_range);
+        }
+
+        let tail_range = body_range.end..item_range.end;
+        if tail_range.start < tail_range.end {
+            outline_ranges.push(tail_range);
+        }
+    } else {
+        item_range.start.column = 0;
+        item_range.end.column = buffer.line_len(item_range.end.row);
+        outline_ranges.push(item_range);
+    }
+}
+
+pub fn merge_ranges(ranges: &mut Vec<Range<Point>>) {
+    ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
+
+    let mut index = 1;
+    while index < ranges.len() {
+        let mut prev_range_end = ranges[index - 1].end;
+        if prev_range_end.column > 0 {
+            prev_range_end += Point::new(1, 0);
+        }
+
+        if (prev_range_end + Point::new(1, 0))
+            .cmp(&ranges[index].start)
+            .is_ge()
+        {
+            let removed = ranges.remove(index);
+            if removed.end.cmp(&ranges[index - 1].end).is_gt() {
+                ranges[index - 1].end = removed.end;
+            }
+        } else {
+            index += 1;
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::{TestAppContext, prelude::*};
+    use indoc::indoc;
+    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
+    use pretty_assertions::assert_eq;
+    use std::{fmt::Write as _, sync::Arc};
+    use util::test::marked_text_ranges;
+
+    #[gpui::test]
+    fn test_rust(cx: &mut TestAppContext) {
+        let table = [
+            (
+                indoc! {r#"
+                    struct User {
+                        first_name: String,
+                        «last_name»: String,
+                        age: u32,
+                        email: String,
+                        create_at: Instant,
+                    }
+
+                    impl User {
+                        pub fn first_name(&self) -> String {
+                            self.first_name.clone()
+                        }
+
+                        pub fn full_name(&self) -> String {
+                    «        format!("{} {}", self.first_name, self.last_name)
+                    »    }
+                    }
+                "#},
+                indoc! {r#"
+                    struct User {
+                        first_name: String,
+                        last_name: String,
+                    …
+                    }
+
+                    impl User {
+                    …
+                        pub fn full_name(&self) -> String {
+                            format!("{} {}", self.first_name, self.last_name)
+                        }
+                    }
+                "#},
+            ),
+            (
+                indoc! {r#"
+                    struct «User» {
+                        first_name: String,
+                        last_name: String,
+                        age: u32,
+                    }
+
+                    impl User {
+                        // methods
+                    }
+                "#
+                },
+                indoc! {r#"
+                    struct User {
+                        first_name: String,
+                        last_name: String,
+                        age: u32,
+                    }
+                    …
+                "#},
+            ),
+            (
+                indoc!
{r#" + trait «FooProvider» { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + ids.iter() + .map(|id| self.provide_foo(*id)) + .collect() + } + + fn sync(&self); + } + "# + }, + indoc! {r#" + trait FooProvider { + const NAME: &'static str; + + fn provide_foo(&self, id: usize) -> Foo; + + fn provide_foo_batched(&self, ids: &[usize]) -> Vec { + … + } + + fn sync(&self); + } + "#}, + ), + ]; + + for (input, expected_output) in table { + let (input, ranges) = marked_text_ranges(&input, false); + let buffer = + cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx)); + buffer.read_with(cx, |buffer, _cx| { + let ranges: Vec> = ranges + .into_iter() + .map(|range| range.to_point(&buffer)) + .collect(); + + let excerpts = assemble_excerpts(&buffer.snapshot(), ranges); + + let output = format_excerpts(buffer, &excerpts); + assert_eq!(output, expected_output); + }); + } + } + + fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String { + let mut output = String::new(); + let file_line_count = buffer.max_point().row; + let mut current_row = 0; + for excerpt in excerpts { + if excerpt.text.is_empty() { + continue; + } + if current_row < excerpt.point_range.start.row { + writeln!(&mut output, "…").unwrap(); + } + current_row = excerpt.point_range.start.row; + + for line in excerpt.text.to_string().lines() { + output.push_str(line); + output.push('\n'); + current_row += 1; + } + } + if current_row < file_line_count { + writeln!(&mut output, "…").unwrap(); + } + output + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(language::tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +} diff --git a/crates/edit_prediction_context2/src/edit_prediction_context2.rs b/crates/edit_prediction_context2/src/edit_prediction_context2.rs new file mode 100644 index 0000000000000000000000000000000000000000..f8790478547ddb8b7b873015846f2af6c1bcbc2c --- /dev/null +++ b/crates/edit_prediction_context2/src/edit_prediction_context2.rs @@ -0,0 +1,465 @@ +use crate::assemble_excerpts::assemble_excerpts; +use anyhow::Result; +use collections::HashMap; +use futures::{FutureExt, StreamExt as _, channel::mpsc, future}; +use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _}; +use project::{LocationLink, Project, ProjectPath}; +use serde::{Serialize, Serializer}; +use smallvec::SmallVec; +use std::{ + collections::hash_map, + ops::Range, + sync::Arc, + time::{Duration, Instant}, +}; +use util::{RangeExt as _, ResultExt}; + +mod assemble_excerpts; +#[cfg(test)] +mod edit_prediction_context_tests; +#[cfg(test)] +mod fake_definition_lsp; + +pub struct RelatedExcerptStore { + project: WeakEntity, + related_files: Vec, + cache: HashMap>, + update_tx: mpsc::UnboundedSender<(Entity, Anchor)>, +} + +pub enum RelatedExcerptStoreEvent { + StartedRefresh, + FinishedRefresh { + cache_hit_count: usize, + cache_miss_count: usize, + mean_definition_latency: Duration, + max_definition_latency: Duration, + }, +} + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct Identifier { + pub name: String, + pub range: Range, +} + +enum DefinitionTask { 
+    CacheHit(Arc<CacheEntry>),
+    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
+}
+
+#[derive(Debug)]
+struct CacheEntry {
+    definitions: SmallVec<[CachedDefinition; 1]>,
+}
+
+#[derive(Clone, Debug)]
+struct CachedDefinition {
+    path: ProjectPath,
+    buffer: Entity<Buffer>,
+    anchor_range: Range<Anchor>,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct RelatedFile {
+    #[serde(serialize_with = "serialize_project_path")]
+    pub path: ProjectPath,
+    #[serde(skip)]
+    pub buffer: WeakEntity<Buffer>,
+    pub excerpts: Vec<RelatedExcerpt>,
+    pub max_row: u32,
+}
+
+impl RelatedFile {
+    pub fn merge_excerpts(&mut self) {
+        self.excerpts.sort_unstable_by(|a, b| {
+            a.point_range
+                .start
+                .cmp(&b.point_range.start)
+                .then(b.point_range.end.cmp(&a.point_range.end))
+        });
+
+        let mut index = 1;
+        while index < self.excerpts.len() {
+            if self.excerpts[index - 1]
+                .point_range
+                .end
+                .cmp(&self.excerpts[index].point_range.start)
+                .is_ge()
+            {
+                let removed = self.excerpts.remove(index);
+                if removed
+                    .point_range
+                    .end
+                    .cmp(&self.excerpts[index - 1].point_range.end)
+                    .is_gt()
+                {
+                    self.excerpts[index - 1].point_range.end = removed.point_range.end;
+                    self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
+                }
+            } else {
+                index += 1;
+            }
+        }
+    }
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct RelatedExcerpt {
+    #[serde(skip)]
+    pub anchor_range: Range<Anchor>,
+    #[serde(serialize_with = "serialize_point_range")]
+    pub point_range: Range<Point>,
+    #[serde(serialize_with = "serialize_rope")]
+    pub text: Rope,
+}
+
+fn serialize_project_path<S: Serializer>(
+    project_path: &ProjectPath,
+    serializer: S,
+) -> Result<S::Ok, S::Error> {
+    project_path.path.serialize(serializer)
+}
+
+fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
+    rope.to_string().serialize(serializer)
+}
+
+fn serialize_point_range<S: Serializer>(
+    range: &Range<Point>,
+    serializer: S,
+) -> Result<S::Ok, S::Error> {
+    [
+        [range.start.row, range.start.column],
+        [range.end.row, range.end.column],
+    ]
+    .serialize(serializer)
+}
+
+const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
+
+impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
+
+impl RelatedExcerptStore {
+    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
+        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
+        cx.spawn(async move |this, cx| {
+            let executor = cx.background_executor().clone();
+            while let Some((mut buffer, mut position)) = update_rx.next().await {
+                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
+                loop {
+                    futures::select_biased!
+                    {
+                        next = update_rx.next() => {
+                            if let Some((new_buffer, new_position)) = next {
+                                buffer = new_buffer;
+                                position = new_position;
+                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
+                            } else {
+                                return anyhow::Ok(());
+                            }
+                        }
+                        _ = timer => break,
+                    }
+                }
+
+                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
+            }
+            anyhow::Ok(())
+        })
+        .detach_and_log_err(cx);
+
+        RelatedExcerptStore {
+            project: project.downgrade(),
+            update_tx,
+            related_files: Vec::new(),
+            cache: Default::default(),
+        }
+    }
+
+    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
+        self.update_tx.unbounded_send((buffer, position)).ok();
+    }
+
+    pub fn related_files(&self) -> &[RelatedFile] {
+        &self.related_files
+    }
+
+    async fn fetch_excerpts(
+        this: WeakEntity<Self>,
+        buffer: Entity<Buffer>,
+        position: Anchor,
+        cx: &mut AsyncApp,
+    ) -> Result<()> {
+        let (project, snapshot) = this.read_with(cx, |this, cx| {
+            (this.project.upgrade(), buffer.read(cx).snapshot())
+        })?;
+        let Some(project) = project else {
+            return Ok(());
+        };
+
+        let file = snapshot.file().cloned();
+        if let Some(file) = &file {
+            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
+        }
+
+        this.update(cx, |_, cx| {
+            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
+        })?;
+
+        let identifiers = cx
+            .background_spawn(async move { identifiers_for_position(&snapshot, position) })
+            .await;
+
+        let async_cx = cx.clone();
+        let start_time = Instant::now();
+        let futures = this.update(cx, |this, cx| {
+            identifiers
+                .into_iter()
+                .filter_map(|identifier| {
+                    let task = if let Some(entry) = this.cache.get(&identifier) {
+                        DefinitionTask::CacheHit(entry.clone())
+                    } else {
+                        DefinitionTask::CacheMiss(
+                            this.project
+                                .update(cx, |project, cx| {
+                                    project.definitions(&buffer, identifier.range.start, cx)
+                                })
+                                .ok()?,
+                        )
+                    };
+
+                    let cx = async_cx.clone();
+                    let project = project.clone();
+                    Some(async move {
+                        match task {
+                            DefinitionTask::CacheHit(cache_entry) => {
+                                Some((identifier, cache_entry, None))
+                            }
+                            DefinitionTask::CacheMiss(task) => {
+                                let locations = task.await.log_err()??;
+                                let duration = start_time.elapsed();
+                                cx.update(|cx| {
+                                    (
+                                        identifier,
+                                        Arc::new(CacheEntry {
+                                            definitions: locations
+                                                .into_iter()
+                                                .filter_map(|location| {
+                                                    process_definition(location, &project, cx)
+                                                })
+                                                .collect(),
+                                        }),
+                                        Some(duration),
+                                    )
+                                })
+                                .ok()
+                            }
+                        }
+                    })
+                })
+                .collect::<Vec<_>>()
+        })?;
+
+        let mut cache_hit_count = 0;
+        let mut cache_miss_count = 0;
+        let mut mean_definition_latency = Duration::ZERO;
+        let mut max_definition_latency = Duration::ZERO;
+        let mut new_cache = HashMap::default();
+        new_cache.reserve(futures.len());
+        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
+            new_cache.insert(identifier, entry);
+            if let Some(duration) = duration {
+                cache_miss_count += 1;
+                mean_definition_latency += duration;
+                max_definition_latency = max_definition_latency.max(duration);
+            } else {
+                cache_hit_count += 1;
+            }
+        }
+        mean_definition_latency /= cache_miss_count.max(1) as u32;
+
+        let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
+
+        if let Some(file) = &file {
+            log::debug!(
+                "finished retrieving context buffer:{}, latency:{:?}",
+                file.path().as_unix_str(),
+                start_time.elapsed()
+            );
+        }
+
+        this.update(cx, |this, cx| {
+            this.cache = new_cache;
+            this.related_files = related_files;
+            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
+                cache_hit_count,
+                cache_miss_count,
+                mean_definition_latency,
+                max_definition_latency,
+            });
+        })?;
+
+        anyhow::Ok(())
+    }
+}
+
+async fn rebuild_related_files(
+    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
+    cx: &mut AsyncApp,
+) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
+    let mut snapshots = HashMap::default();
+    for entry in new_entries.values() {
+        for definition in &entry.definitions {
+            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
+                definition
+                    .buffer
+                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
+                    .await;
+                e.insert(
+                    definition
+                        .buffer
+                        .read_with(cx, |buffer, _| buffer.snapshot())?,
+                );
+            }
+        }
+    }
+
+    Ok(cx
+        .background_spawn(async move {
+            let mut files = Vec::<RelatedFile>::new();
+            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
+            let mut paths_by_buffer = HashMap::default();
+            for entry in new_entries.values() {
+                for definition in &entry.definitions {
+                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
+                        continue;
+                    };
+                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
+                    ranges_by_buffer
+                        .entry(definition.buffer.clone())
+                        .or_default()
+                        .push(definition.anchor_range.to_point(snapshot));
+                }
+            }
+
+            for (buffer, ranges) in ranges_by_buffer {
+                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
+                    continue;
+                };
+                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
+                    continue;
+                };
+                let excerpts = assemble_excerpts(snapshot, ranges);
+                files.push(RelatedFile {
+                    path: project_path.clone(),
+                    buffer: buffer.downgrade(),
+                    excerpts,
+                    max_row: snapshot.max_point().row,
+                });
+            }
+
+            files.sort_by_key(|file| file.path.clone());
+            (new_entries, files)
+        })
+        .await)
+}
+
+fn process_definition(
+    location: LocationLink,
+    project: &Entity<Project>,
+    cx: &mut App,
+) -> Option<CachedDefinition> {
+    let buffer = location.target.buffer.read(cx);
+    let anchor_range = location.target.range;
+    let file = buffer.file()?;
+    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
+    if worktree.read(cx).is_single_file() {
+        return None;
+    }
+    Some(CachedDefinition {
+        path: ProjectPath {
+            worktree_id: file.worktree_id(cx),
+            path: file.path().clone(),
+        },
+        buffer: location.target.buffer,
+        anchor_range,
+    })
+}
+
+/// Gets all of the identifiers that are present in the given line, and its containing
+/// outline items.
+fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec<Identifier> {
+    let offset = position.to_offset(buffer);
+    let point = buffer.offset_to_point(offset);
+
+    let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point());
+    let mut ranges = vec![line_range.to_offset(&buffer)];
+
+    // Include the range of the outline item itself, but not its body.
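+    // E.g. with the cursor inside a method, this covers the method's signature and
+    // the signatures of its enclosing items (impl, struct, ...), while each item's
+    // body is excluded via `body_range` below.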
+    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
+    for item in outline_items {
+        if let Some(body_range) = item.body_range(&buffer) {
+            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
+        } else {
+            ranges.push(item.range.clone());
+        }
+    }
+
+    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
+    ranges.dedup_by(|a, b| {
+        if a.start <= b.end {
+            b.start = b.start.min(a.start);
+            b.end = b.end.max(a.end);
+            true
+        } else {
+            false
+        }
+    });
+
+    let mut identifiers = Vec::new();
+    let outer_range =
+        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
+
+    let mut captures = buffer
+        .syntax
+        .captures(outer_range.clone(), &buffer.text, |grammar| {
+            grammar
+                .highlights_config
+                .as_ref()
+                .map(|config| &config.query)
+        });
+
+    for range in ranges {
+        captures.set_byte_range(range.start..outer_range.end);
+
+        let mut last_range = None;
+        while let Some(capture) = captures.peek() {
+            let node_range = capture.node.byte_range();
+            if node_range.start > range.end {
+                break;
+            }
+            let config = captures.grammars()[capture.grammar_index]
+                .highlights_config
+                .as_ref();
+
+            if let Some(config) = config
+                && config.identifier_capture_indices.contains(&capture.index)
+                && range.contains_inclusive(&node_range)
+                && Some(&node_range) != last_range.as_ref()
+            {
+                let name = buffer.text_for_range(node_range.clone()).collect();
+                identifiers.push(Identifier {
+                    range: buffer.anchor_after(node_range.start)
+                        ..buffer.anchor_before(node_range.end),
+                    name,
+                });
+                last_range = Some(node_range);
+            }
+
+            captures.advance();
+        }
+    }
+
+    identifiers
+}
diff --git a/crates/edit_prediction_context2/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context2/src/edit_prediction_context_tests.rs
new file mode 100644
index 0000000000000000000000000000000000000000..05d1becc2167837a5f9741d77e7bc96c2f5b8d34
--- /dev/null
+++ b/crates/edit_prediction_context2/src/edit_prediction_context_tests.rs
@@ -0,0 +1,360 @@
+use super::*;
+use futures::channel::mpsc::UnboundedReceiver;
+use gpui::TestAppContext;
+use indoc::indoc;
+use language::{Language, LanguageConfig, LanguageMatcher, Point, ToPoint as _, tree_sitter_rust};
+use lsp::FakeLanguageServer;
+use project::{FakeFs, LocationLink, Project};
+use serde_json::json;
+use settings::SettingsStore;
+use std::sync::Arc;
+use util::path;
+
+#[gpui::test]
+async fn test_edit_prediction_context(cx: &mut TestAppContext) {
+    init_test(cx);
+    let fs = FakeFs::new(cx.executor());
+    fs.insert_tree(path!("/root"), test_project_1()).await;
+
+    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+    let mut servers = setup_fake_lsp(&project, cx);
+
+    let (buffer, _handle) = project
+        .update(cx, |project, cx| {
+            project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx)
+        })
+        .await
+        .unwrap();
+
+    let _server = servers.next().await.unwrap();
+    cx.run_until_parked();
+
+    let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
+    related_excerpt_store.update(cx, |store, cx| {
+        let position = {
+            let buffer = buffer.read(cx);
+            let offset = buffer.text().find("todo").unwrap();
+            buffer.anchor_before(offset)
+        };
+
+        store.refresh(buffer.clone(), position, cx);
+    });
+
+    cx.executor().advance_clock(DEBOUNCE_DURATION);
+    related_excerpt_store.update(cx, |store, _| {
+        let excerpts = store.related_files();
+        assert_related_files(
+            &excerpts,
+            &[
+                (
+                    "src/company.rs",
+                    &[indoc!
{" + pub struct Company { + owner: Arc, + address: Address, + }"}], + ), + ( + "src/main.rs", + &[ + indoc! {" + pub struct Session { + company: Arc, + } + + impl Session { + pub fn set_company(&mut self, company: Arc) {"}, + indoc! {" + } + }"}, + ], + ), + ( + "src/person.rs", + &[ + indoc! {" + impl Person { + pub fn get_first_name(&self) -> &str { + &self.first_name + }"}, + "}", + ], + ), + ], + ); + }); +} + +#[gpui::test] +async fn test_fake_definition_lsp(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/root"), test_project_1()).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let mut servers = setup_fake_lsp(&project, cx); + + let (buffer, _handle) = project + .update(cx, |project, cx| { + project.open_local_buffer_with_lsp(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + let _server = servers.next().await.unwrap(); + cx.run_until_parked(); + + let buffer_text = buffer.read_with(cx, |buffer, _| buffer.text()); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("Address {").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub struct Address {"], cx); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("State::CA").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub enum State {"], cx); + + let definitions = project + .update(cx, |project, cx| { + let offset = buffer_text.find("to_string()").unwrap(); + project.definitions(&buffer, offset, cx) + }) + .await + .unwrap() + .unwrap(); + assert_definitions(&definitions, &["pub fn to_string(&self) -> String {"], cx); +} + +fn init_test(cx: &mut TestAppContext) { + let settings_store = cx.update(|cx| SettingsStore::test(cx)); + cx.set_global(settings_store); + env_logger::try_init().ok(); +} + +fn setup_fake_lsp( + project: &Entity, + cx: &mut TestAppContext, +) -> UnboundedReceiver { + let (language_registry, fs) = project.read_with(cx, |project, _| { + (project.languages().clone(), project.fs().clone()) + }); + let language = rust_lang(); + language_registry.add(language.clone()); + fake_definition_lsp::register_fake_definition_server(&language_registry, language, fs) +} + +fn test_project_1() -> serde_json::Value { + let person_rs = indoc! {r#" + pub struct Person { + first_name: String, + last_name: String, + email: String, + age: u32, + } + + impl Person { + pub fn get_first_name(&self) -> &str { + &self.first_name + } + + pub fn get_last_name(&self) -> &str { + &self.last_name + } + + pub fn get_email(&self) -> &str { + &self.email + } + + pub fn get_age(&self) -> u32 { + self.age + } + } + "#}; + + let address_rs = indoc! {r#" + pub struct Address { + street: String, + city: String, + state: State, + zip: u32, + } + + pub enum State { + CA, + OR, + WA, + TX, + // ... + } + + impl Address { + pub fn get_street(&self) -> &str { + &self.street + } + + pub fn get_city(&self) -> &str { + &self.city + } + + pub fn get_state(&self) -> State { + self.state + } + + pub fn get_zip(&self) -> u32 { + self.zip + } + } + "#}; + + let company_rs = indoc! 
{r#" + use super::person::Person; + use super::address::Address; + + pub struct Company { + owner: Arc, + address: Address, + } + + impl Company { + pub fn get_owner(&self) -> &Person { + &self.owner + } + + pub fn get_address(&self) -> &Address { + &self.address + } + + pub fn to_string(&self) -> String { + format!("{} ({})", self.owner.first_name, self.address.city) + } + } + "#}; + + let main_rs = indoc! {r#" + use std::sync::Arc; + use super::person::Person; + use super::address::Address; + use super::company::Company; + + pub struct Session { + company: Arc, + } + + impl Session { + pub fn set_company(&mut self, company: Arc) { + self.company = company; + if company.owner != self.company.owner { + log("new owner", company.owner.get_first_name()); todo(); + } + } + } + + fn main() { + let company = Company { + owner: Arc::new(Person { + first_name: "John".to_string(), + last_name: "Doe".to_string(), + email: "john@example.com".to_string(), + age: 30, + }), + address: Address { + street: "123 Main St".to_string(), + city: "Anytown".to_string(), + state: State::CA, + zip: 12345, + }, + }; + + println!("Company: {}", company.to_string()); + } + "#}; + + json!({ + "src": { + "person.rs": person_rs, + "address.rs": address_rs, + "company.rs": company_rs, + "main.rs": main_rs, + }, + }) +} + +fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &[&str])]) { + let actual_files = actual_files + .iter() + .map(|file| { + let excerpts = file + .excerpts + .iter() + .map(|excerpt| excerpt.text.to_string()) + .collect::>(); + (file.path.path.as_unix_str(), excerpts) + }) + .collect::>(); + let expected_excerpts = expected_files + .iter() + .map(|(path, texts)| { + ( + *path, + texts + .iter() + .map(|line| line.to_string()) + .collect::>(), + ) + }) + .collect::>(); + pretty_assertions::assert_eq!(actual_files, expected_excerpts) +} + +fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &mut TestAppContext) { + let actual_first_lines = definitions + .iter() + .map(|definition| { + definition.target.buffer.read_with(cx, |buffer, _| { + let mut start = definition.target.range.start.to_point(&buffer); + start.column = 0; + let end = Point::new(start.row, buffer.line_len(start.row)); + buffer + .text_for_range(start..end) + .collect::() + .trim() + .to_string() + }) + }) + .collect::>(); + + assert_eq!(actual_first_lines, first_lines); +} + +pub(crate) fn rust_lang() -> Arc { + Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + first_line_pattern: None, + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) + .unwrap() + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap(), + ) +} diff --git a/crates/edit_prediction_context2/src/fake_definition_lsp.rs b/crates/edit_prediction_context2/src/fake_definition_lsp.rs new file mode 100644 index 0000000000000000000000000000000000000000..31fb681309c610a37c7f886390ef5adb92ee78ef --- /dev/null +++ b/crates/edit_prediction_context2/src/fake_definition_lsp.rs @@ -0,0 +1,329 @@ +use collections::HashMap; +use futures::channel::mpsc::UnboundedReceiver; +use language::{Language, LanguageRegistry}; +use lsp::{ + FakeLanguageServer, LanguageServerBinary, TextDocumentSyncCapability, TextDocumentSyncKind, Uri, +}; +use parking_lot::Mutex; +use project::Fs; +use std::{ops::Range, path::PathBuf, 
+    sync::Arc};
+use tree_sitter::{Parser, QueryCursor, StreamingIterator, Tree};
+
+/// Registers a fake language server that implements go-to-definition using tree-sitter,
+/// making the assumption that all names are unique, and all variables' types are
+/// explicitly declared.
+pub fn register_fake_definition_server(
+    language_registry: &Arc<LanguageRegistry>,
+    language: Arc<Language>,
+    fs: Arc<dyn Fs>,
+) -> UnboundedReceiver<FakeLanguageServer> {
+    let index = Arc::new(Mutex::new(DefinitionIndex::new(language.clone())));
+
+    language_registry.register_fake_lsp(
+        language.name(),
+        language::FakeLspAdapter {
+            name: "fake-definition-lsp",
+            initialization_options: None,
+            prettier_plugins: Vec::new(),
+            disk_based_diagnostics_progress_token: None,
+            disk_based_diagnostics_sources: Vec::new(),
+            language_server_binary: LanguageServerBinary {
+                path: PathBuf::from("fake-definition-lsp"),
+                arguments: Vec::new(),
+                env: None,
+            },
+            capabilities: lsp::ServerCapabilities {
+                definition_provider: Some(lsp::OneOf::Left(true)),
+                text_document_sync: Some(TextDocumentSyncCapability::Kind(
+                    TextDocumentSyncKind::FULL,
+                )),
+                ..Default::default()
+            },
+            label_for_completion: None,
+            initializer: Some(Box::new({
+                move |server| {
+                    server.handle_notification::<lsp::notification::DidOpenTextDocument, _>({
+                        let index = index.clone();
+                        move |params, _cx| {
+                            index
+                                .lock()
+                                .open_buffer(params.text_document.uri, &params.text_document.text);
+                        }
+                    });
+
+                    server.handle_notification::<lsp::notification::DidCloseTextDocument, _>({
+                        let index = index.clone();
+                        let fs = fs.clone();
+                        move |params, cx| {
+                            let uri = params.text_document.uri;
+                            let path = uri.to_file_path().ok();
+                            index.lock().mark_buffer_closed(&uri);
+
+                            if let Some(path) = path {
+                                let index = index.clone();
+                                let fs = fs.clone();
+                                cx.spawn(async move |_cx| {
+                                    if let Ok(content) = fs.load(&path).await {
+                                        index.lock().index_file(uri, &content);
+                                    }
+                                })
+                                .detach();
+                            }
+                        }
+                    });
+
+                    server.handle_notification::<lsp::notification::DidChangeWatchedFiles, _>({
+                        let index = index.clone();
+                        let fs = fs.clone();
+                        move |params, cx| {
+                            let index = index.clone();
+                            let fs = fs.clone();
+                            cx.spawn(async move |_cx| {
+                                for event in params.changes {
+                                    if index.lock().is_buffer_open(&event.uri) {
+                                        continue;
+                                    }
+
+                                    match event.typ {
+                                        lsp::FileChangeType::DELETED => {
+                                            index.lock().remove_definitions_for_file(&event.uri);
+                                        }
+                                        lsp::FileChangeType::CREATED
+                                        | lsp::FileChangeType::CHANGED => {
+                                            if let Some(path) = event.uri.to_file_path().ok() {
+                                                if let Ok(content) = fs.load(&path).await {
+                                                    index.lock().index_file(event.uri, &content);
+                                                }
+                                            }
+                                        }
+                                        _ => {}
+                                    }
+                                }
+                            })
+                            .detach();
+                        }
+                    });
+
+                    server.handle_notification::<lsp::notification::DidChangeTextDocument, _>({
+                        let index = index.clone();
+                        move |params, _cx| {
+                            if let Some(change) = params.content_changes.into_iter().last() {
+                                index
+                                    .lock()
+                                    .index_file(params.text_document.uri, &change.text);
+                            }
+                        }
+                    });
+
+                    server.handle_notification::<lsp::notification::DidChangeWorkspaceFolders, _>(
+                        {
+                            let index = index.clone();
+                            let fs = fs.clone();
+                            move |params, cx| {
+                                let index = index.clone();
+                                let fs = fs.clone();
+                                let files = fs.as_fake().files();
+                                cx.spawn(async move |_cx| {
+                                    for folder in params.event.added {
+                                        let Ok(path) = folder.uri.to_file_path() else {
+                                            continue;
+                                        };
+                                        for file in &files {
+                                            if let Some(uri) = Uri::from_file_path(&file).ok()
+                                                && file.starts_with(&path)
+                                                && let Ok(content) = fs.load(&file).await
+                                            {
+                                                index.lock().index_file(uri, &content);
+                                            }
+                                        }
+                                    }
+                                })
+                                .detach();
+                            }
+                        },
+                    );
+
+                    server.set_request_handler::<lsp::request::GotoDefinition, _, _>({
+                        let index = index.clone();
+                        move |params, _cx| {
+                            let result = index.lock().get_definitions(
+                                params.text_document_position_params.text_document.uri,
+                                params.text_document_position_params.position,
+                            );
+                            async move { Ok(result) }
+                        }
+                    });
+                }
+            })),
+        },
+    )
+}
+
+struct DefinitionIndex {
+    language: Arc<Language>,
+    definitions: HashMap<String, Vec<lsp::Location>>,
+    files: HashMap<Uri, FileEntry>,
+}
+
+#[derive(Debug)]
+struct FileEntry {
+    contents: String,
+    is_open_in_buffer: bool,
+}
+
+impl DefinitionIndex {
+    fn new(language: Arc<Language>) -> Self {
+        Self {
+            language,
+            definitions: HashMap::default(),
+            files: HashMap::default(),
+        }
+    }
+
+    fn remove_definitions_for_file(&mut self, uri: &Uri) {
+        self.definitions.retain(|_, locations| {
+            locations.retain(|loc| &loc.uri != uri);
+            !locations.is_empty()
+        });
+        self.files.remove(uri);
+    }
+
+    fn open_buffer(&mut self, uri: Uri, content: &str) {
+        self.index_file_inner(uri, content, true);
+    }
+
+    fn mark_buffer_closed(&mut self, uri: &Uri) {
+        if let Some(entry) = self.files.get_mut(uri) {
+            entry.is_open_in_buffer = false;
+        }
+    }
+
+    fn is_buffer_open(&self, uri: &Uri) -> bool {
+        self.files
+            .get(uri)
+            .map(|entry| entry.is_open_in_buffer)
+            .unwrap_or(false)
+    }
+
+    fn index_file(&mut self, uri: Uri, content: &str) {
+        self.index_file_inner(uri, content, false);
+    }
+
+    fn index_file_inner(&mut self, uri: Uri, content: &str, is_open_in_buffer: bool) -> Option<()> {
+        self.remove_definitions_for_file(&uri);
+        let grammar = self.language.grammar()?;
+        let outline_config = grammar.outline_config.as_ref()?;
+        let mut parser = Parser::new();
+        parser.set_language(&grammar.ts_language).ok()?;
+        let tree = parser.parse(content, None)?;
+        let declarations = extract_declarations_from_tree(&tree, content, outline_config);
+        for (name, byte_range) in declarations {
+            let range = byte_range_to_lsp_range(content, byte_range);
+            let location = lsp::Location {
+                uri: uri.clone(),
+                range,
+            };
+            self.definitions
+                .entry(name)
+                .or_insert_with(Vec::new)
+                .push(location);
+        }
+        self.files.insert(
+            uri,
+            FileEntry {
+                contents: content.to_string(),
+                is_open_in_buffer,
+            },
+        );
+
+        Some(())
+    }
+
+    fn get_definitions(
+        &mut self,
+        uri: Uri,
+        position: lsp::Position,
+    ) -> Option<lsp::GotoDefinitionResponse> {
+        let entry = self.files.get(&uri)?;
+        let name = word_at_position(&entry.contents, position)?;
+        let locations = self.definitions.get(name).cloned()?;
+        Some(lsp::GotoDefinitionResponse::Array(locations))
+    }
+}
+
+fn extract_declarations_from_tree(
+    tree: &Tree,
+    content: &str,
+    outline_config: &language::OutlineConfig,
+) -> Vec<(String, Range<usize>)> {
+    let mut cursor = QueryCursor::new();
+    let mut declarations = Vec::new();
+    let mut matches = cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes());
+    while let Some(query_match) = matches.next() {
+        let mut name_range: Option<Range<usize>> = None;
+        let mut has_item_range = false;
+
+        for capture in query_match.captures {
+            let range = capture.node.byte_range();
+            if capture.index == outline_config.name_capture_ix {
+                name_range = Some(range);
+            } else if capture.index == outline_config.item_capture_ix {
+                has_item_range = true;
+            }
+        }
+
+        if let Some(name_range) = name_range
+            && has_item_range
+        {
+            let name = content[name_range.clone()].to_string();
+            if declarations.iter().any(|(n, _)| n == &name) {
+                continue;
+            }
+            declarations.push((name, name_range));
+        }
+    }
+    declarations
+}
+
+fn byte_range_to_lsp_range(content: &str, byte_range: Range<usize>) -> lsp::Range {
+    let start = byte_offset_to_position(content, byte_range.start);
+    let end = byte_offset_to_position(content, byte_range.end);
+    lsp::Range { start, end }
+}
+
+fn byte_offset_to_position(content: &str, offset: usize) -> lsp::Position {
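+    // Walks the text, counting newlines. Columns are counted in characters rather
+    // than the UTF-16 code units LSP strictly specifies, which is close enough for
+    // the ASCII-only fixtures used in these tests.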
+    let mut line = 0;
+    let mut character = 0;
+    let mut current_offset = 0;
+    for ch in content.chars() {
+        if current_offset >= offset {
+            break;
+        }
+        if ch == '\n' {
+            line += 1;
+            character = 0;
+        } else {
+            character += 1;
+        }
+        current_offset += ch.len_utf8();
+    }
+    lsp::Position { line, character }
+}
+
+fn word_at_position(content: &str, position: lsp::Position) -> Option<&str> {
+    let mut lines = content.lines();
+    let line = lines.nth(position.line as usize)?;
+    let column = position.character as usize;
+    if column > line.len() {
+        return None;
+    }
+    let start = line[..column]
+        .rfind(|c: char| !c.is_alphanumeric() && c != '_')
+        .map(|i| i + 1)
+        .unwrap_or(0);
+    let end = line[column..]
+        .find(|c: char| !c.is_alphanumeric() && c != '_')
+        .map(|i| i + column)
+        .unwrap_or(line.len());
+    Some(&line[start..end]).filter(|word| !word.is_empty())
+}
diff --git a/crates/extension_host/src/extension_store_test.rs b/crates/extension_host/src/extension_store_test.rs
index 85a3a720ce8c62fc4317756ec264926c981864c4..6d3aadeb5ac498b3948d871a0a87f7ecf49b6bd8 100644
--- a/crates/extension_host/src/extension_store_test.rs
+++ b/crates/extension_host/src/extension_store_test.rs
@@ -705,7 +705,7 @@ async fn test_extension_store_with_test_extension(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    let mut fake_servers = language_registry.register_fake_language_server(
+    let mut fake_servers = language_registry.register_fake_lsp_server(
         LanguageServerName("gleam".into()),
         lsp::ServerCapabilities {
             completion_provider: Some(Default::default()),
diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs
index a46f7cc35912d4c6da42ba69f7aee6d25caca2e7..7166a01ef64bff9e47c70cac47910f714ae2dc39 100644
--- a/crates/language/src/buffer.rs
+++ b/crates/language/src/buffer.rs
@@ -4022,6 +4022,20 @@ impl BufferSnapshot {
         })
     }
 
+    pub fn outline_items_as_offsets_containing(
+        &self,
+        range: Range<usize>,
+        include_extra_context: bool,
+        theme: Option<&SyntaxTheme>,
+    ) -> Vec<OutlineItem<usize>> {
+        self.outline_items_containing_internal(
+            range,
+            include_extra_context,
+            theme,
+            |buffer, range| range.to_offset(buffer),
+        )
+    }
+
     fn outline_items_containing_internal(
         &self,
         range: Range<usize>,
diff --git a/crates/language/src/buffer_tests.rs b/crates/language/src/buffer_tests.rs
index efef0a08127bc66f9c6d8f21fe5a545dbee20fb1..e95bc544a56ecf9d561936ca48b10ccffcb23e72 100644
--- a/crates/language/src/buffer_tests.rs
+++ b/crates/language/src/buffer_tests.rs
@@ -784,28 +784,48 @@ async fn test_outline(cx: &mut gpui::TestAppContext) {
         .unindent();
 
     let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
-    let outline = buffer.update(cx, |buffer, _| buffer.snapshot().outline(None));
+    let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot());
+    let outline = snapshot.outline(None);
 
-    assert_eq!(
+    pretty_assertions::assert_eq!(
         outline
             .items
             .iter()
-            .map(|item| (item.text.as_str(), item.depth))
+            .map(|item| (
+                item.text.as_str(),
+                item.depth,
+                item.to_point(&snapshot).body_range(&snapshot)
+                    .map(|range| minimize_space(&snapshot.text_for_range(range).collect::<String>()))
+            ))
             .collect::<Vec<_>>(),
         &[
-            ("struct Person", 0),
-            ("name", 1),
-            ("age", 1),
-            ("mod module", 0),
-            ("enum LoginState", 1),
-            ("LoggedOut", 2),
-            ("LoggingOn", 2),
-            ("LoggedIn", 2),
-            ("person", 3),
-            ("time", 3),
-            ("impl Eq for Person", 0),
-            ("impl Drop for Person", 0),
-            ("fn drop", 1),
+            ("struct Person", 0, Some("name: String, age: usize,".to_string())),
+            ("name", 1, None),
+            ("age", 1, None),
+            (
+                "mod module",
+                0,
+                Some(
+                    "enum LoginState { LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, } }".to_string()
+                )
+            ),
+            (
+                "enum LoginState",
+                1,
+                Some("LoggedOut, LoggingOn, LoggedIn { person: Person, time: Instant, }".to_string())
+            ),
+            ("LoggedOut", 2, None),
+            ("LoggingOn", 2, None),
+            ("LoggedIn", 2, Some("person: Person, time: Instant,".to_string())),
+            ("person", 3, None),
+            ("time", 3, None),
+            ("impl Eq for Person", 0, None),
+            (
+                "impl Drop for Person",
+                0,
+                Some("fn drop(&mut self) { println!(\"bye\"); }".to_string())
+            ),
+            ("fn drop", 1, Some("println!(\"bye\");".to_string())),
         ]
     );
 
@@ -840,6 +860,11 @@ async fn test_outline(cx: &mut gpui::TestAppContext) {
         ]
     );
 
+    fn minimize_space(text: &str) -> String {
+        static WHITESPACE: LazyLock<Regex> = LazyLock::new(|| Regex::new("[\\n\\s]+").unwrap());
+        WHITESPACE.replace_all(text, " ").trim().to_string()
+    }
+
     async fn search<'a>(
         outline: &'a Outline<Anchor>,
         query: &'a str,
diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs
index 022eb89e6d2b378b8c4305c81887060d776bb411..a0b04efd1b1366a101812d8656965637c13769a5 100644
--- a/crates/language/src/language_registry.rs
+++ b/crates/language/src/language_registry.rs
@@ -437,26 +437,14 @@ impl LanguageRegistry {
         language_name: impl Into<LanguageName>,
         mut adapter: crate::FakeLspAdapter,
     ) -> futures::channel::mpsc::UnboundedReceiver<lsp::FakeLanguageServer> {
-        let language_name = language_name.into();
         let adapter_name = LanguageServerName(adapter.name.into());
         let capabilities = adapter.capabilities.clone();
         let initializer = adapter.initializer.take();
-        let adapter = CachedLspAdapter::new(Arc::new(adapter));
-        {
-            let mut state = self.state.write();
-            state
-                .lsp_adapters
-                .entry(language_name)
-                .or_default()
-                .push(adapter.clone());
-            state.all_lsp_adapters.insert(adapter.name(), adapter);
-        }
-
-        self.register_fake_language_server(adapter_name, capabilities, initializer)
+        self.register_fake_lsp_adapter(language_name, adapter);
+        self.register_fake_lsp_server(adapter_name, capabilities, initializer)
     }
 
     /// Register a fake lsp adapter (without the language server)
-    /// The returned channel receives a new instance of the language server every time it is started
     #[cfg(any(feature = "test-support", test))]
     pub fn register_fake_lsp_adapter(
         &self,
@@ -479,7 +467,7 @@ impl LanguageRegistry {
     /// Register a fake language server (without the adapter)
     /// The returned channel receives a new instance of the language server every time it is started
     #[cfg(any(feature = "test-support", test))]
-    pub fn register_fake_language_server(
+    pub fn register_fake_lsp_server(
         &self,
diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs
index 3bf4e35c6b5cfd7f2a1f221bde4cec181998ab6a..068f8e1aa39ca3422fda8eb5706c00de6f2f62ce 100644
--- a/crates/language/src/language_settings.rs
+++ b/crates/language/src/language_settings.rs
@@ -373,6 +373,8 @@ impl InlayHintSettings {
 pub struct EditPredictionSettings {
     /// The provider that supplies edit predictions.
     pub provider: settings::EditPredictionProvider,
+    /// Whether to use the experimental edit prediction context retrieval system.
+    pub use_context: bool,
     /// A list of globs representing files that edit predictions should be disabled for.
     /// This list adds to a pre-existing, sensible default set of globs.
     /// Any additional ones you add are combined with them.
@@ -622,6 +624,11 @@ impl settings::Settings for AllLanguageSettings {
             .features
             .as_ref()
             .and_then(|f| f.edit_prediction_provider);
+        let use_edit_prediction_context = all_languages
+            .features
+            .as_ref()
+            .and_then(|f| f.experimental_edit_prediction_context_retrieval)
+            .unwrap_or_default();
 
         let edit_predictions = all_languages.edit_predictions.clone().unwrap();
         let edit_predictions_mode = edit_predictions.mode.unwrap();
@@ -668,6 +675,7 @@ impl settings::Settings for AllLanguageSettings {
             } else {
                 EditPredictionProvider::None
             },
+            use_context: use_edit_prediction_context,
             disabled_globs: disabled_globs
                 .iter()
                 .filter_map(|g| {
diff --git a/crates/language/src/outline.rs b/crates/language/src/outline.rs
index 2ce2b42734465a4710a7439f5e2225debc96b04a..875042bfc83ae42fb580ab848029902d68988511 100644
--- a/crates/language/src/outline.rs
+++ b/crates/language/src/outline.rs
@@ -1,4 +1,4 @@
-use crate::{BufferSnapshot, Point, ToPoint};
+use crate::{BufferSnapshot, Point, ToPoint, ToTreeSitterPoint};
 use fuzzy::{StringMatch, StringMatchCandidate};
 use gpui::{BackgroundExecutor, HighlightStyle};
 use std::ops::Range;
@@ -48,6 +48,54 @@ impl<T: ToPoint> OutlineItem<T> {
             .map(|r| r.start.to_point(buffer)..r.end.to_point(buffer)),
         }
     }
+
+    pub fn body_range(&self, buffer: &BufferSnapshot) -> Option<Range<Point>> {
+        if let Some(range) = self.body_range.as_ref() {
+            return Some(range.start.to_point(buffer)..range.end.to_point(buffer));
+        }
+
+        let range = self.range.start.to_point(buffer)..self.range.end.to_point(buffer);
+        let start_indent = buffer.indent_size_for_line(range.start.row);
+        let node = buffer.syntax_ancestor(range.clone())?;
+
+        let mut cursor = node.walk();
+        loop {
+            let node = cursor.node();
+            if node.start_position() >= range.start.to_ts_point()
+                && node.end_position() <= range.end.to_ts_point()
+            {
+                break;
+            }
+            cursor.goto_first_child_for_point(range.start.to_ts_point());
+        }
+
+        if !cursor.goto_last_child() {
+            return None;
+        }
+        let body_node = loop {
+            let node = cursor.node();
+            if node.child_count() > 0 {
+                break node;
+            }
+            if !cursor.goto_previous_sibling() {
+                return None;
+            }
+        };
+
+        let mut start_row = body_node.start_position().row as u32;
+        let mut end_row = body_node.end_position().row as u32;
+
+        while start_row < end_row && buffer.indent_size_for_line(start_row) == start_indent {
+            start_row += 1;
+        }
+        while start_row < end_row && buffer.indent_size_for_line(end_row - 1) == start_indent {
+            end_row -= 1;
+        }
+        if start_row < end_row {
+            return Some(Point::new(start_row, 0)..Point::new(end_row, 0));
+        }
+        None
+    }
 }
 
 impl<T> Outline<T> {
diff --git a/crates/language/src/syntax_map.rs b/crates/language/src/syntax_map.rs
index 8574d52ff900563ddfb733c09204caab5eb6ae44..17285ca315fb64dd518d00039d28266c0a7f51ab 100644
--- a/crates/language/src/syntax_map.rs
+++ b/crates/language/src/syntax_map.rs
@@ -1215,6 +1215,19 @@ impl<'a> SyntaxMapMatches<'a> {
         true
     }
+
+    // pub fn set_byte_range(&mut self, range: Range<usize>) {
+    //     for layer in &mut self.layers {
+    //         layer.matches.set_byte_range(range.clone());
+    //         layer.advance();
+    //     }
+    //     self.layers.sort_unstable_by_key(|layer| layer.sort_key());
+    //     self.active_layer_count = self
+    //         .layers
+    //         .iter()
+    //         .position(|layer| !layer.has_next)
+    //         .unwrap_or(self.layers.len());
+    // }
 }
 
 impl SyntaxMapCapturesLayer<'_> {
diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs
index 1e6ecddb5f2599a0ded0180f3afd3df0f197f037..a91d1d055d582eb2f2de4883314ad5984238103a 100644
--- a/crates/remote_server/src/remote_editing_tests.rs
+++ b/crates/remote_server/src/remote_editing_tests.rs
@@ -452,7 +452,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
     });
 
     let mut fake_lsp = server_cx.update(|cx| {
-        headless.read(cx).languages.register_fake_language_server(
+        headless.read(cx).languages.register_fake_lsp_server(
             LanguageServerName("rust-analyzer".into()),
             lsp::ServerCapabilities {
                 completion_provider: Some(lsp::CompletionOptions::default()),
@@ -476,7 +476,7 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
             ..FakeLspAdapter::default()
         },
     );
-    headless.read(cx).languages.register_fake_language_server(
+    headless.read(cx).languages.register_fake_lsp_server(
         LanguageServerName("fake-analyzer".into()),
         lsp::ServerCapabilities {
             completion_provider: Some(lsp::CompletionOptions::default()),
@@ -669,7 +669,7 @@ async fn test_remote_cancel_language_server_work(
     });
 
     let mut fake_lsp = server_cx.update(|cx| {
-        headless.read(cx).languages.register_fake_language_server(
+        headless.read(cx).languages.register_fake_lsp_server(
            LanguageServerName("rust-analyzer".into()),
            Default::default(),
            None,
diff --git a/crates/settings/src/settings_content/language.rs b/crates/settings/src/settings_content/language.rs
index 6b8a372269d44935e20426a0b669fed96a33dadf..b466b4e0dd88bf41e0f77f67a38842305c11906f 100644
--- a/crates/settings/src/settings_content/language.rs
+++ b/crates/settings/src/settings_content/language.rs
@@ -62,6 +62,8 @@ impl merge_from::MergeFrom for AllLanguageSettingsContent {
 pub struct FeaturesContent {
     /// Determines which edit prediction provider to use.
     pub edit_prediction_provider: Option<EditPredictionProvider>,
+    /// Enables the experimental edit prediction context retrieval system.
+    pub experimental_edit_prediction_context_retrieval: Option<bool>,
 }
 
 /// The provider that supplies edit predictions.
diff --git a/crates/text/src/anchor.rs b/crates/text/src/anchor.rs
index c6d47a1e233b2fdf58fbc73adb622fc801832335..bf660b1302466e2b244a86b3d1e58ea2b6991067 100644
--- a/crates/text/src/anchor.rs
+++ b/crates/text/src/anchor.rs
@@ -8,10 +8,14 @@ use sum_tree::{Bias, Dimensions};
 /// A timestamped position in a buffer
 #[derive(Copy, Clone, Eq, PartialEq, Hash)]
 pub struct Anchor {
+    /// The timestamp of the operation that inserted the text
+    /// in which this anchor is located.
     pub timestamp: clock::Lamport,
-    /// The byte offset in the buffer
+    /// The byte offset into the text inserted in the operation
+    /// at `timestamp`.
     pub offset: usize,
-    /// Describes which character the anchor is biased towards
+    /// Whether this anchor stays attached to the character *before* or *after*
+    /// the offset.
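+    /// For example, when text is inserted exactly at the anchor's position, an
+    /// anchor attached to the preceding character stays before the insertion,
+    /// while one attached to the following character ends up after it.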
     pub bias: Bias,
     pub buffer_id: Option<BufferId>,
 }
diff --git a/crates/ui/src/components/data_table.rs b/crates/ui/src/components/data_table.rs
index f7cce2b85ffa3aeb9f97634c6c0fa65c46f4a8e7..9cd2a5cb7a0d802d170fcfbe6a812027c779d942 100644
--- a/crates/ui/src/components/data_table.rs
+++ b/crates/ui/src/components/data_table.rs
@@ -485,6 +485,7 @@ pub struct Table<const COLS: usize> {
     interaction_state: Option<Entity<TableInteractionState>>,
     col_widths: Option>,
     map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>,
+    use_ui_font: bool,
     empty_table_callback: Option<Rc<dyn Fn(&mut Window, &mut App) -> AnyElement>>,
 }
 
@@ -498,6 +499,7 @@ impl<const COLS: usize> Table<COLS> {
             rows: TableContents::Vec(Vec::new()),
             interaction_state: None,
             map_row: None,
+            use_ui_font: true,
             empty_table_callback: None,
             col_widths: None,
         }
@@ -590,6 +592,11 @@ impl<const COLS: usize> Table<COLS> {
         self
     }
 
+    pub fn no_ui_font(mut self) -> Self {
+        self.use_ui_font = false;
+        self
+    }
+
     pub fn map_row(
         mut self,
         callback: impl Fn((usize, Stateful<Div>
), &mut Window, &mut App) -> AnyElement + 'static,
@@ -618,8 +625,8 @@ fn base_cell_style(width: Option<Length>) -> Div {
         .overflow_hidden()
 }
 
-fn base_cell_style_text(width: Option<Length>, cx: &App) -> Div {
-    base_cell_style(width).text_ui(cx)
+fn base_cell_style_text(width: Option<Length>, use_ui_font: bool, cx: &App) -> Div {
+    base_cell_style(width).when(use_ui_font, |el| el.text_ui(cx))
 }
 
 pub fn render_table_row(
@@ -656,7 +663,12 @@ pub fn render_table_row(
             .map(IntoElement::into_any_element)
             .into_iter()
             .zip(column_widths)
-            .map(|(cell, width)| base_cell_style_text(width, cx).px_1().py_0p5().child(cell)),
+            .map(|(cell, width)| {
+                base_cell_style_text(width, table_context.use_ui_font, cx)
+                    .px_1()
+                    .py_0p5()
+                    .child(cell)
+            }),
     );
 
     let row = if let Some(map_row) = table_context.map_row {
@@ -700,7 +712,7 @@ pub fn render_table_header(
         .border_color(cx.theme().colors().border)
         .children(headers.into_iter().enumerate().zip(column_widths).map(
             |((header_idx, h), width)| {
-                base_cell_style_text(width, cx)
+                base_cell_style_text(width, table_context.use_ui_font, cx)
                     .child(h)
                     .id(ElementId::NamedInteger(
                         shared_element_id.clone(),
@@ -739,6 +751,7 @@ pub struct TableRenderContext<const COLS: usize> {
     pub total_row_count: usize,
     pub column_widths: Option<[Length; COLS]>,
     pub map_row: Option<Rc<dyn Fn((usize, Stateful<Div>), &mut Window, &mut App) -> AnyElement>>,
+    pub use_ui_font: bool,
 }
 
 impl<const COLS: usize> TableRenderContext<COLS> {
@@ -748,6 +761,7 @@ impl<const COLS: usize> TableRenderContext<COLS> {
             total_row_count: table.rows.len(),
             column_widths: table.col_widths.as_ref().map(|widths| widths.lengths(cx)),
             map_row: table.map_row.clone(),
+            use_ui_font: table.use_ui_font,
         }
     }
 }
diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml
index 7429fcb8e8d5e4b485f69ea87c37d7d670c3b199..b90934e67c2a689e1f7bb9704ff28a408de3049a 100644
--- a/crates/zeta/Cargo.toml
+++ b/crates/zeta/Cargo.toml
@@ -30,6 +30,7 @@ credentials_provider.workspace = true
 db.workspace = true
 edit_prediction.workspace = true
 edit_prediction_context.workspace = true
+edit_prediction_context2.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
 fs.workspace = true
diff --git a/crates/zeta/src/assemble_excerpts.rs b/crates/zeta/src/assemble_excerpts.rs
deleted file mode 100644
index f2a5b5adb1fcffab945cd9bdb88153bc5e494138..0000000000000000000000000000000000000000
--- a/crates/zeta/src/assemble_excerpts.rs
+++ /dev/null
@@ -1,173 +0,0 @@
-use cloud_llm_client::predict_edits_v3::Excerpt;
-use edit_prediction_context::Line;
-use language::{BufferSnapshot, Point};
-use std::ops::Range;
-
-pub fn assemble_excerpts(
-    buffer: &BufferSnapshot,
-    merged_line_ranges: impl IntoIterator<Item = Range<Line>>,
-) -> Vec<Excerpt> {
-    let mut output = Vec::new();
-
-    let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
-    let mut outline_items = outline_items.into_iter().peekable();
-
-    for range in merged_line_ranges {
-        let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
-
-        while let Some(outline_item) = outline_items.peek() {
-            if outline_item.range.start >= point_range.start {
-                break;
-            }
-            if outline_item.range.end > point_range.start {
-                let mut point_range = outline_item.source_range_for_text.clone();
-                point_range.start.column = 0;
-                point_range.end.column = buffer.line_len(point_range.end.row);
-
-                output.push(Excerpt {
-                    start_line: Line(point_range.start.row),
-                    text: buffer
-                        .text_for_range(point_range.clone())
-                        .collect::<String>()
-                        .into(),
-                })
-            }
-            outline_items.next();
-        }
-
-        output.push(Excerpt {
-            start_line: Line(point_range.start.row),
-            text: buffer
.text_for_range(point_range.clone()) - .collect::() - .into(), - }) - } - - output -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - use cloud_llm_client::predict_edits_v3; - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt}; - use pretty_assertions::assert_eq; - use util::test::marked_text_ranges; - - #[gpui::test] - fn test_rust(cx: &mut TestAppContext) { - let table = [ - ( - indoc! {r#" - struct User { - first_name: String, - « last_name: String, - ageˇ: u32, - » email: String, - create_at: Instant, - } - - impl User { - pub fn first_name(&self) -> String { - self.first_name.clone() - } - - pub fn full_name(&self) -> String { - « format!("{} {}", self.first_name, self.last_name) - » } - } - "#}, - indoc! {r#" - 1|struct User { - … - 3| last_name: String, - 4| age<|cursor|>: u32, - … - 9|impl User { - … - 14| pub fn full_name(&self) -> String { - 15| format!("{} {}", self.first_name, self.last_name) - … - "#}, - ), - ( - indoc! {r#" - struct User { - first_name: String, - « last_name: String, - age: u32, - } - »"# - }, - indoc! {r#" - 1|struct User { - … - 3| last_name: String, - 4| age: u32, - 5|} - "#}, - ), - ]; - - for (input, expected_output) in table { - let input_without_ranges = input.replace(['«', '»'], ""); - let input_without_caret = input.replace('ˇ', ""); - let cursor_offset = input_without_ranges.find('ˇ'); - let (input, ranges) = marked_text_ranges(&input_without_caret, false); - let buffer = - cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx)); - buffer.read_with(cx, |buffer, _cx| { - let insertions = cursor_offset - .map(|offset| { - let point = buffer.offset_to_point(offset); - vec![( - predict_edits_v3::Point { - line: Line(point.row), - column: point.column, - }, - "<|cursor|>", - )] - }) - .unwrap_or_default(); - let ranges: Vec> = ranges - .into_iter() - .map(|range| { - let point_range = range.to_point(&buffer); - Line(point_range.start.row)..Line(point_range.end.row) - }) - .collect(); - - let mut output = String::new(); - cloud_zeta2_prompt::write_excerpts( - assemble_excerpts(&buffer.snapshot(), ranges).iter(), - &insertions, - Line(buffer.max_point().row), - true, - &mut output, - ); - assert_eq!(output, expected_output); - }); - } - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(language::tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/zeta/src/retrieval_search.rs b/crates/zeta/src/retrieval_search.rs index bcc0233ff7e872a151ecddf2cf55a3cb434f02b3..f429f167744422c3641b5a68ca662af48c8e1614 100644 --- a/crates/zeta/src/retrieval_search.rs +++ b/crates/zeta/src/retrieval_search.rs @@ -1,6 +1,7 @@ use anyhow::Result; use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery; use collections::HashMap; +use edit_prediction_context2::{RelatedExcerpt, RelatedFile}; use futures::{ StreamExt, channel::mpsc::{self, UnboundedSender}, @@ -8,7 +9,7 @@ use futures::{ use gpui::{AppContext, AsyncApp, Entity}; use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint}; use project::{ - Project, WorktreeSettings, + Project, ProjectPath, WorktreeSettings, search::{SearchQuery, SearchResult}, }; use smol::channel; @@ -20,14 
+21,14 @@ use util::{ use workspace::item::Settings as _; #[cfg(feature = "eval-support")] -type CachedSearchResults = std::collections::BTreeMap>>; +type CachedSearchResults = std::collections::BTreeMap>>; pub async fn run_retrieval_searches( queries: Vec, project: Entity, #[cfg(feature = "eval-support")] eval_cache: Option>, cx: &mut AsyncApp, -) -> Result, Vec>>> { +) -> Result> { #[cfg(feature = "eval-support")] let cache = if let Some(eval_cache) = eval_cache { use crate::EvalCacheEntryKind; @@ -54,24 +55,44 @@ pub async fn run_retrieval_searches( if let Some(cached_results) = eval_cache.read(key) { let file_results = serde_json::from_str::(&cached_results) .context("Failed to deserialize cached search results")?; - let mut results = HashMap::default(); + let mut results = Vec::new(); for (path, ranges) in file_results { + let project_path = project.update(cx, |project, cx| { + project.find_project_path(path, cx).unwrap() + })?; let buffer = project .update(cx, |project, cx| { - let project_path = project.find_project_path(path, cx).unwrap(); - project.open_buffer(project_path, cx) + project.open_buffer(project_path.clone(), cx) })? .await?; let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; let mut ranges: Vec<_> = ranges .into_iter() - .map(|range| { - snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end) - }) + .map( + |Range { + start: (start_row, start_col), + end: (end_row, end_col), + }| { + snapshot.anchor_before(Point::new(start_row, start_col)) + ..snapshot.anchor_after(Point::new(end_row, end_col)) + }, + ) .collect(); merge_anchor_ranges(&mut ranges, &snapshot); - results.insert(buffer, ranges); + results.push(RelatedFile { + path: project_path, + buffer: buffer.downgrade(), + excerpts: ranges + .into_iter() + .map(|range| RelatedExcerpt { + point_range: range.to_point(&snapshot), + text: snapshot.as_rope().slice(range.to_offset(&snapshot)), + anchor_range: range, + }) + .collect(), + max_row: snapshot.max_point().row, + }); } return Ok(results); @@ -117,14 +138,29 @@ pub async fn run_retrieval_searches( #[cfg(feature = "eval-support")] let cache = cache.clone(); cx.background_spawn(async move { - let mut results: HashMap, Vec>> = HashMap::default(); + let mut results: Vec = Vec::default(); let mut snapshots = HashMap::default(); let mut total_bytes = 0; - 'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await { - snapshots.insert(buffer.entity_id(), snapshot); - let existing = results.entry(buffer).or_default(); - existing.reserve(excerpts.len()); + 'outer: while let Some((project_path, buffer, snapshot, excerpts)) = results_rx.next().await + { + let existing = results + .iter_mut() + .find(|related_file| related_file.buffer.entity_id() == buffer.entity_id()); + let existing = match existing { + Some(existing) => existing, + None => { + results.push(RelatedFile { + path: project_path, + buffer: buffer.downgrade(), + excerpts: Vec::new(), + max_row: snapshot.max_point().row, + }); + results.last_mut().unwrap() + } + }; + existing.excerpts.reserve(excerpts.len()); for (range, size) in excerpts { // Blunt trimming of the results until we have a proper algorithmic filtering step @@ -133,24 +169,34 @@ pub async fn run_retrieval_searches( if total_bytes + size > MAX_RESULTS_LEN { break 'outer; } total_bytes += size; - existing.push(range); + existing.excerpts.push(RelatedExcerpt { + point_range: range.to_point(&snapshot), + text: snapshot.as_rope().slice(range.to_offset(&snapshot)), + anchor_range: range, +
}); } + snapshots.insert(buffer.entity_id(), snapshot); } #[cfg(feature = "eval-support")] if let Some((cache, queries, key)) = cache { let cached_results: CachedSearchResults = results .iter() - .filter_map(|(buffer, ranges)| { - let snapshot = snapshots.get(&buffer.entity_id())?; - let path = snapshot.file().map(|f| f.path()); - let mut ranges = ranges + .map(|related_file| { + let mut ranges = related_file + .excerpts .iter() - .map(|range| range.to_offset(&snapshot)) + .map( + |RelatedExcerpt { + point_range: Range { start, end }, + .. + }| { + (start.row, start.column)..(end.row, end.column) + }, + ) .collect::>(); ranges.sort_unstable_by_key(|range| (range.start, range.end)); - - Some((path?.as_std_path().to_path_buf(), ranges)) + (related_file.path.path.as_std_path().to_path_buf(), ranges) }) .collect(); cache.write( @@ -160,10 +206,8 @@ pub async fn run_retrieval_searches( ); } - for (buffer, ranges) in results.iter_mut() { - if let Some(snapshot) = snapshots.get(&buffer.entity_id()) { - merge_anchor_ranges(ranges, snapshot); - } + for related_file in results.iter_mut() { + related_file.merge_excerpts(); } Ok(results) @@ -171,6 +215,7 @@ pub async fn run_retrieval_searches( .await } +#[cfg(feature = "eval-support")] pub(crate) fn merge_anchor_ranges(ranges: &mut Vec>, snapshot: &BufferSnapshot) { ranges.sort_unstable_by(|a, b| { a.start @@ -201,6 +246,7 @@ const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5; struct SearchJob { buffer: Entity, snapshot: BufferSnapshot, + project_path: ProjectPath, ranges: Vec>, query_ix: usize, jobs_tx: channel::Sender, @@ -208,7 +254,12 @@ struct SearchJob { async fn run_query( input_query: SearchToolQuery, - results_tx: UnboundedSender<(Entity, BufferSnapshot, Vec<(Range, usize)>)>, + results_tx: UnboundedSender<( + ProjectPath, + Entity, + BufferSnapshot, + Vec<(Range, usize)>, + )>, path_style: PathStyle, exclude_matcher: PathMatcher, project: &Entity, @@ -257,12 +308,21 @@ async fn run_query( .read_with(cx, |buffer, _| buffer.parsing_idle())? 
.await; let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let Some(file) = snapshot.file() else { + continue; + }; + + let project_path = cx.update(|cx| ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + })?; let expanded_ranges: Vec<_> = ranges .into_iter() .filter_map(|range| expand_to_parent_range(&range, &snapshot)) .collect(); jobs_tx .send(SearchJob { + project_path, buffer, snapshot, ranges: expanded_ranges, @@ -301,6 +361,13 @@ async fn run_query( while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + let Some(file) = snapshot.file() else { + continue; + }; + let project_path = cx.update(|cx| ProjectPath { + worktree_id: file.worktree_id(cx), + path: file.path().clone(), + })?; let ranges = ranges .into_iter() @@ -314,7 +381,8 @@ async fn run_query( }) .collect(); - let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges)); + let send_result = + results_tx.unbounded_send((project_path, buffer.clone(), snapshot.clone(), ranges)); if let Err(err) = send_result && !err.is_disconnected() @@ -330,7 +398,12 @@ async fn run_query( } async fn process_nested_search_job( - results_tx: &UnboundedSender<(Entity, BufferSnapshot, Vec<(Range, usize)>)>, + results_tx: &UnboundedSender<( + ProjectPath, + Entity, + BufferSnapshot, + Vec<(Range, usize)>, + )>, queries: &Vec, content_query: &Option, job: SearchJob, @@ -347,6 +420,7 @@ async fn process_nested_search_job( } job.jobs_tx .send(SearchJob { + project_path: job.project_path, buffer: job.buffer, snapshot: job.snapshot, ranges: subranges, @@ -382,7 +456,8 @@ async fn process_nested_search_job( }) .collect(); - let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches)); + let send_result = + results_tx.unbounded_send((job.project_path, job.buffer, job.snapshot, matches)); if let Err(err) = send_result && !err.is_disconnected() @@ -413,230 +488,3 @@ fn expand_to_parent_range( let node = snapshot.syntax_ancestor(line_range)?; Some(node.byte_range()) } - -#[cfg(test)] -mod tests { - use super::*; - use crate::assemble_excerpts::assemble_excerpts; - use cloud_zeta2_prompt::write_codeblock; - use edit_prediction_context::Line; - use gpui::TestAppContext; - use indoc::indoc; - use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust}; - use pretty_assertions::assert_eq; - use project::FakeFs; - use serde_json::json; - use settings::SettingsStore; - use std::path::Path; - use util::path; - - #[gpui::test] - async fn test_retrieval(cx: &mut TestAppContext) { - init_test(cx); - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "user.rs": indoc!{" - pub struct Organization { - owner: Arc, - } - - pub struct User { - first_name: String, - last_name: String, - } - - impl Organization { - pub fn owner(&self) -> Arc { - self.owner.clone() - } - } - - impl User { - pub fn new(first_name: String, last_name: String) -> Self { - Self { - first_name, - last_name - } - } - - pub fn first_name(&self) -> String { - self.first_name.clone() - } - - pub fn last_name(&self) -> String { - self.last_name.clone() - } - } - "}, - "main.rs": indoc!{r#" - fn main() { - let user = User::new(FIRST_NAME.clone(), "doe".into()); - println!("user {:?}", user); - } - "#}, - }), - ) - .await; - - let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await; - project.update(cx, |project, _cx| { - 
project.languages().add(rust_lang().into()) - }); - - assert_results( - &project, - SearchToolQuery { - glob: "user.rs".into(), - syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()], - content: None, - }, - indoc! {r#" - `````root/user.rs - … - impl User { - … - pub fn first_name(&self) -> String { - self.first_name.clone() - } - … - ````` - "#}, - cx, - ) - .await; - - assert_results( - &project, - SearchToolQuery { - glob: "user.rs".into(), - syntax_node: vec!["impl\\s+User".into()], - content: Some("\\.clone".into()), - }, - indoc! {r#" - `````root/user.rs - … - impl User { - … - pub fn first_name(&self) -> String { - self.first_name.clone() - … - pub fn last_name(&self) -> String { - self.last_name.clone() - … - ````` - "#}, - cx, - ) - .await; - - assert_results( - &project, - SearchToolQuery { - glob: "*.rs".into(), - syntax_node: vec![], - content: Some("\\.clone".into()), - }, - indoc! {r#" - `````root/main.rs - fn main() { - let user = User::new(FIRST_NAME.clone(), "doe".into()); - … - ````` - - `````root/user.rs - … - impl Organization { - pub fn owner(&self) -> Arc { - self.owner.clone() - … - impl User { - … - pub fn first_name(&self) -> String { - self.first_name.clone() - … - pub fn last_name(&self) -> String { - self.last_name.clone() - … - ````` - "#}, - cx, - ) - .await; - } - - async fn assert_results( - project: &Entity, - query: SearchToolQuery, - expected_output: &str, - cx: &mut TestAppContext, - ) { - let results = run_retrieval_searches( - vec![query], - project.clone(), - #[cfg(feature = "eval-support")] - None, - &mut cx.to_async(), - ) - .await - .unwrap(); - - let mut results = results.into_iter().collect::>(); - results.sort_by_key(|results| { - results - .0 - .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone()) - }); - - let mut output = String::new(); - for (buffer, ranges) in results { - buffer.read_with(cx, |buffer, cx| { - let excerpts = ranges.into_iter().map(|range| { - let point_range = range.to_point(buffer); - if point_range.end.column > 0 { - Line(point_range.start.row)..Line(point_range.end.row + 1) - } else { - Line(point_range.start.row)..Line(point_range.end.row) - } - }); - - write_codeblock( - &buffer.file().unwrap().full_path(cx), - assemble_excerpts(&buffer.snapshot(), excerpts).iter(), - &[], - Line(buffer.max_point().row), - false, - &mut output, - ); - }); - } - output.pop(); - - assert_eq!(output, expected_output); - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(move |cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - zlog::init_test(); - }); - } -} diff --git a/crates/zeta/src/sweep_ai.rs b/crates/zeta/src/sweep_ai.rs index 8fd5398f3facc807d99951c48c749e9247fb5670..0bc0d1d41e2393212f865e402912f6d760aa252e 100644 --- a/crates/zeta/src/sweep_ai.rs +++ b/crates/zeta/src/sweep_ai.rs @@ -1,6 +1,7 @@ use anyhow::{Context as _, Result}; use cloud_llm_client::predict_edits_v3::Event; use credentials_provider::CredentialsProvider; +use edit_prediction_context2::RelatedFile; use futures::{AsyncReadExt as _, FutureExt, future::Shared}; use gpui::{ App, AppContext as _, Entity, Task, @@ -49,6 +50,7 @@ impl SweepAi 
{ position: language::Anchor, events: Vec>, recent_paths: &VecDeque, + related_files: Vec, diagnostic_search_range: Range, cx: &mut App, ) -> Task>> { @@ -120,6 +122,19 @@ impl SweepAi { }) .collect::>(); + let retrieval_chunks = related_files + .iter() + .flat_map(|related_file| { + related_file.excerpts.iter().map(|excerpt| FileChunk { + file_path: related_file.path.path.as_unix_str().to_string(), + start_line: excerpt.point_range.start.row as usize, + end_line: excerpt.point_range.end.row as usize, + content: excerpt.text.to_string(), + timestamp: None, + }) + }) + .collect(); + let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false); let mut diagnostic_content = String::new(); let mut diagnostic_count = 0; @@ -168,7 +183,7 @@ impl SweepAi { multiple_suggestions: false, branch: None, file_chunks, - retrieval_chunks: vec![], + retrieval_chunks, recent_user_actions: vec![], use_bytes: true, // TODO @@ -320,7 +335,7 @@ struct AutocompleteRequest { pub cursor_position: usize, pub original_file_contents: String, pub file_chunks: Vec, - pub retrieval_chunks: Vec, + pub retrieval_chunks: Vec, pub recent_user_actions: Vec, pub multiple_suggestions: bool, pub privacy_mode_enabled: bool, @@ -337,15 +352,6 @@ struct FileChunk { pub timestamp: Option, } -#[derive(Debug, Clone, Serialize)] -struct RetrievalChunk { - pub file_path: String, - pub start_line: usize, - pub end_line: usize, - pub content: String, - pub timestamp: u64, -} - #[derive(Debug, Clone, Serialize)] struct UserAction { pub action_type: ActionType, diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 33d37d9e3aa0c5c89830d5ec86663330da1daf77..576067b9844cd668c69411d7a4098975db4a5d26 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -1,7 +1,7 @@ use anyhow::{Context as _, Result, anyhow, bail}; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; -use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, Signature}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, @@ -14,31 +14,39 @@ use collections::{HashMap, HashSet}; use command_palette_hooks::CommandPaletteFilter; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use edit_prediction_context::{ - DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, - EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line, - SyntaxIndex, SyntaxIndexState, + EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions, + EditPredictionScoreOptions, Line, SyntaxIndex, +}; +use edit_prediction_context2::{ + RelatedExcerpt, RelatedExcerptStore, RelatedExcerptStoreEvent, RelatedFile, }; use feature_flags::{FeatureFlag, FeatureFlagAppExt as _, PredictEditsRateCompletionsFeatureFlag}; -use futures::channel::mpsc::UnboundedReceiver; -use futures::channel::{mpsc, oneshot}; -use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, select_biased}; +use futures::{ + AsyncReadExt as _, FutureExt as _, StreamExt as _, + channel::{ + mpsc::{self, UnboundedReceiver}, + oneshot, + }, + select_biased, +}; use gpui::BackgroundExecutor; use gpui::{ App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions, http_client::{self, AsyncBody, Method}, prelude::*, }; +use 
language::language_settings::all_language_settings; use language::{ Anchor, Buffer, DiagnosticSet, File, LanguageServerId, Point, ToOffset as _, ToPoint, }; use language::{BufferSnapshot, OffsetRangeExt}; use language_model::{LlmApiToken, RefreshLlmTokenListener}; use open_ai::FunctionDefinition; -use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; +use project::{DisableAiSettings, Project, ProjectItem as _, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; use serde::de::DeserializeOwned; -use settings::{EditPredictionProvider, Settings as _, SettingsStore, update_settings_file}; +use settings::{EditPredictionProvider, Settings, SettingsStore, update_settings_file}; use std::any::{Any as _, TypeId}; use std::collections::{VecDeque, hash_map}; use telemetry_events::EditPredictionRating; @@ -52,11 +60,9 @@ use std::sync::{Arc, LazyLock}; use std::time::{Duration, Instant}; use std::{env, mem}; use thiserror::Error; -use util::rel_path::RelPathBuf; use util::{LogErrorFuture, RangeExt as _, ResultExt as _, TryFutureExt}; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; -pub mod assemble_excerpts; mod license_detection; mod onboarding_modal; mod prediction; @@ -71,7 +77,6 @@ pub mod zeta1; #[cfg(test)] mod zeta_tests; -use crate::assemble_excerpts::assemble_excerpts; use crate::license_detection::LicenseDetectionWatcher; use crate::onboarding_modal::ZedPredictModal; pub use crate::prediction::EditPrediction; @@ -115,8 +120,7 @@ pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPrediction target_before_cursor_over_total_bytes: 0.5, }; -pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = - ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS); +pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Lsp(DEFAULT_EXCERPT_OPTIONS); pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions { excerpt: DEFAULT_EXCERPT_OPTIONS, @@ -190,6 +194,7 @@ pub struct Zeta { llm_token: LlmApiToken, _llm_token_subscription: Subscription, projects: HashMap, + use_context: bool, options: ZetaOptions, update_required: bool, debug_tx: Option>, @@ -225,6 +230,7 @@ pub struct ZetaOptions { pub enum ContextMode { Agentic(AgenticContextOptions), Syntax(EditPredictionContextOptions), + Lsp(EditPredictionExcerptOptions), } #[derive(Debug, Clone, PartialEq)] @@ -237,6 +243,7 @@ impl ContextMode { match self { ContextMode::Agentic(options) => &options.excerpt, ContextMode::Syntax(options) => &options.excerpt, + ContextMode::Lsp(options) => &options, } } } @@ -244,23 +251,22 @@ impl ContextMode { #[derive(Debug)] pub enum ZetaDebugInfo { ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo), - SearchQueriesGenerated(ZetaSearchQueryDebugInfo), - SearchQueriesExecuted(ZetaContextRetrievalDebugInfo), - ContextRetrievalFinished(ZetaContextRetrievalDebugInfo), + ContextRetrievalFinished(ZetaContextRetrievalFinishedDebugInfo), EditPredictionRequested(ZetaEditPredictionDebugInfo), } #[derive(Debug)] pub struct ZetaContextRetrievalStartedDebugInfo { - pub project: Entity, + pub project_entity_id: EntityId, pub timestamp: Instant, pub search_prompt: String, } #[derive(Debug)] -pub struct ZetaContextRetrievalDebugInfo { - pub project: Entity, +pub struct ZetaContextRetrievalFinishedDebugInfo { + pub project_entity_id: EntityId, pub timestamp: Instant, + pub metadata: Vec<(&'static str, SharedString)>, } #[derive(Debug)] @@ -273,17 +279,9 @@ pub struct ZetaEditPredictionDebugInfo { pub 
response_rx: oneshot::Receiver<(Result, Duration)>, } -#[derive(Debug)] -pub struct ZetaSearchQueryDebugInfo { - pub project: Entity, - pub timestamp: Instant, - pub search_queries: Vec, -} - pub type RequestDebugInfo = predict_edits_v3::DebugInfo; struct ZetaProject { - syntax_index: Option>, events: VecDeque>, last_event: Option, recent_paths: VecDeque, @@ -291,16 +289,26 @@ struct ZetaProject { current_prediction: Option, next_pending_prediction_id: usize, pending_predictions: ArrayVec, + context_updates_tx: smol::channel::Sender<()>, + context_updates_rx: smol::channel::Receiver<()>, last_prediction_refresh: Option<(EntityId, Instant)>, cancelled_predictions: HashSet, - context: Option, Vec>>>, - refresh_context_task: Option>>>, - refresh_context_debounce_task: Option>>, - refresh_context_timestamp: Option, + context: ZetaProjectContext, license_detection_watchers: HashMap>, _subscription: gpui::Subscription, } +enum ZetaProjectContext { + Syntax(Entity), + Lsp(Entity), + Agentic { + refresh_context_task: Option>>>, + refresh_context_debounce_task: Option>>, + refresh_context_timestamp: Option, + context: Vec, + }, +} + impl ZetaProject { pub fn events(&self, cx: &App) -> Vec> { self.events @@ -521,11 +529,12 @@ impl Zeta { }) .detach(); - Self { + let mut this = Self { projects: HashMap::default(), client, user_store, options: DEFAULT_OPTIONS, + use_context: false, llm_token, _llm_token_subscription: cx.subscribe( &refresh_llm_token_listener, @@ -549,7 +558,22 @@ impl Zeta { reject_predictions_tx: reject_tx, rated_predictions: Default::default(), shown_predictions: Default::default(), - } + }; + + this.enable_or_disable_context_retrieval(cx); + let weak_this = cx.weak_entity(); + cx.on_flags_ready(move |_, cx| { + weak_this + .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx)) + .ok(); + }) + .detach(); + cx.observe_global::(|this, cx| { + this.enable_or_disable_context_retrieval(cx); + }) + .detach(); + + this } pub fn set_edit_prediction_model(&mut self, model: ZetaEditPredictionModel) { @@ -584,29 +608,29 @@ impl Zeta { self.options = options; } + pub fn set_use_context(&mut self, use_context: bool) { + self.use_context = use_context; + } + pub fn clear_history(&mut self) { for zeta_project in self.projects.values_mut() { zeta_project.events.clear(); } } - pub fn context_for_project( - &self, + pub fn context_for_project<'a>( + &'a self, project: &Entity, - ) -> impl Iterator, &[Range])> { + cx: &'a App, + ) -> &'a [RelatedFile] { self.projects .get(&project.entity_id()) - .and_then(|project| { - Some( - project - .context - .as_ref()? - .iter() - .map(|(buffer, ranges)| (buffer.clone(), ranges.as_slice())), - ) + .and_then(|project| match &project.context { + ZetaProjectContext::Syntax(_) => None, + ZetaProjectContext::Lsp(store) => Some(store.read(cx).related_files()), + ZetaProjectContext::Agentic { context, .. 
} => Some(context.as_slice()), }) - .into_iter() - .flatten() + .unwrap_or(&[]) } pub fn usage(&self, cx: &App) -> Option { @@ -636,34 +660,122 @@ impl Zeta { project: &Entity, cx: &mut Context, ) -> &mut ZetaProject { + let entity_id = project.entity_id(); + let (context_updates_tx, context_updates_rx) = smol::channel::unbounded(); self.projects - .entry(project.entity_id()) + .entry(entity_id) .or_insert_with(|| ZetaProject { - syntax_index: if let ContextMode::Syntax(_) = &self.options.context { - Some(cx.new(|cx| { + context: match &self.options.context { + ContextMode::Agentic(_) => ZetaProjectContext::Agentic { + refresh_context_task: None, + refresh_context_debounce_task: None, + refresh_context_timestamp: None, + context: Vec::new(), + }, + ContextMode::Syntax(_) => ZetaProjectContext::Syntax(cx.new(|cx| { SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx) - })) - } else { - None + })), + ContextMode::Lsp(_) => { + let related_excerpt_store = + cx.new(|cx| RelatedExcerptStore::new(project, cx)); + cx.subscribe( + &related_excerpt_store, + move |this, _, event, _| match event { + RelatedExcerptStoreEvent::StartedRefresh => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( + ZetaContextRetrievalStartedDebugInfo { + project_entity_id: entity_id, + timestamp: Instant::now(), + search_prompt: String::new(), + }, + )) + .ok(); + } + } + RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + } => { + if let Some(debug_tx) = this.debug_tx.clone() { + debug_tx + .unbounded_send( + ZetaDebugInfo::ContextRetrievalFinished( + ZetaContextRetrievalFinishedDebugInfo { + project_entity_id: entity_id, + timestamp: Instant::now(), + metadata: vec![ + ( + "Cache Hits", + format!( + "{}/{}", + cache_hit_count, + cache_hit_count + + cache_miss_count + ) + .into(), + ), + ( + "Max LSP Time", + format!( + "{} ms", + max_definition_latency + .as_millis() + ) + .into(), + ), + ( + "Mean LSP Time", + format!( + "{} ms", + mean_definition_latency + .as_millis() + ) + .into(), + ), + ], + }, + ), + ) + .ok(); + } + if let Some(project_state) = this.projects.get(&entity_id) { + project_state.context_updates_tx.send_blocking(()).ok(); + } + } + }, + ) + .detach(); + ZetaProjectContext::Lsp(related_excerpt_store) + } }, events: VecDeque::new(), last_event: None, recent_paths: VecDeque::new(), + context_updates_rx, + context_updates_tx, registered_buffers: HashMap::default(), current_prediction: None, cancelled_predictions: HashSet::default(), pending_predictions: ArrayVec::new(), next_pending_prediction_id: 0, last_prediction_refresh: None, - context: None, - refresh_context_task: None, - refresh_context_debounce_task: None, - refresh_context_timestamp: None, license_detection_watchers: HashMap::default(), _subscription: cx.subscribe(&project, Self::handle_project_event), }) } + pub fn project_context_updates( + &self, + project: &Entity, + ) -> Option> { + let project_state = self.projects.get(&project.entity_id())?; + Some(project_state.context_updates_rx.clone()) + } + fn handle_project_event( &mut self, project: Entity, @@ -1349,6 +1461,11 @@ impl Zeta { position, events, &zeta_project.recent_paths, + if self.use_context { + self.context_for_project(&project, cx).to_vec() + } else { + Vec::new() + }, diagnostic_search_range.clone(), cx, ), @@ -1480,73 +1597,34 @@ impl Zeta { trigger: PredictEditsRequestTrigger, cx: &mut Context, ) -> 
Task>> { - let project_state = self.projects.get(&project.entity_id()); - - let index_state = project_state.and_then(|state| { - state - .syntax_index - .as_ref() - .map(|syntax_index| syntax_index.read_with(cx, |index, _cx| index.state().clone())) - }); let options = self.options.clone(); let buffer_snapshotted_at = Instant::now(); - let Some(excerpt_path) = active_snapshot + + let Some((excerpt_path, active_project_path)) = active_snapshot .file() - .map(|path| -> Arc { path.full_path(cx).into() }) + .map(|file| -> Arc { file.full_path(cx).into() }) + .zip(active_buffer.read(cx).project_path(cx)) else { return Task::ready(Err(anyhow!("No file path for excerpt"))); }; + let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); - let worktree_snapshots = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect::>(); let debug_tx = self.debug_tx.clone(); let diagnostics = active_snapshot.diagnostic_sets().clone(); let file = active_buffer.read(cx).file(); - let parent_abs_path = project::File::from_dyn(file).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } - }); + + let active_file_full_path = file.as_ref().map(|f| f.full_path(cx)); // TODO data collection let can_collect_data = file .as_ref() .map_or(false, |file| self.can_collect_file(project, file, cx)); - let empty_context_files = HashMap::default(); - let context_files = project_state - .and_then(|project_state| project_state.context.as_ref()) - .unwrap_or(&empty_context_files); - - #[cfg(feature = "eval-support")] - let parsed_fut = futures::future::join_all( - context_files - .keys() - .map(|buffer| buffer.read(cx).parsing_idle()), - ); - - let mut included_files = context_files - .iter() - .filter_map(|(buffer_entity, ranges)| { - let buffer = buffer_entity.read(cx); - Some(( - buffer_entity.clone(), - buffer.snapshot(), - buffer.file()?.full_path(cx).into(), - ranges.clone(), - )) - }) - .collect::>(); - - included_files.sort_by(|(_, _, path_a, ranges_a), (_, _, path_b, ranges_b)| { - (path_a, ranges_a.len()).cmp(&(path_b, ranges_b.len())) - }); + let mut included_files = self.context_for_project(project, cx).to_vec(); #[cfg(feature = "eval-support")] let eval_cache = self.eval_cache.clone(); @@ -1554,15 +1632,6 @@ impl Zeta { let request_task = cx.background_spawn({ let active_buffer = active_buffer.clone(); async move { - #[cfg(feature = "eval-support")] - parsed_fut.await; - - let index_state = if let Some(index_state) = index_state { - Some(index_state.lock_owned().await) - } else { - None - }; - let cursor_offset = position.to_offset(&active_snapshot); let cursor_point = cursor_offset.to_point(&active_snapshot); @@ -1576,110 +1645,84 @@ impl Zeta { options.max_diagnostic_bytes, ); - let cloud_request = match options.context { - ContextMode::Agentic(context_options) => { - let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &active_snapshot, - &context_options.excerpt, - index_state.as_deref(), - ) else { - return Ok((None, None)); - }; + let excerpt_options = options.context.excerpt(); - let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) - ..active_snapshot.anchor_before(excerpt.range.end); + let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &active_snapshot, + &excerpt_options, + None, + ) else { + return Ok((None, None)); + }; - if let Some(buffer_ix) = - 
included_files.iter().position(|(_, snapshot, _, _)| { - snapshot.remote_id() == active_snapshot.remote_id() - }) - { - let (_, buffer, _, ranges) = &mut included_files[buffer_ix]; - ranges.push(excerpt_anchor_range); - retrieval_search::merge_anchor_ranges(ranges, buffer); - let last_ix = included_files.len() - 1; - included_files.swap(buffer_ix, last_ix); - } else { - included_files.push(( - active_buffer.clone(), - active_snapshot.clone(), - excerpt_path.clone(), - vec![excerpt_anchor_range], - )); - } + let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) + ..active_snapshot.anchor_before(excerpt.range.end); + let related_excerpt = RelatedExcerpt { + anchor_range: excerpt_anchor_range.clone(), + point_range: Point::new(excerpt.line_range.start.0, 0) + ..Point::new(excerpt.line_range.end.0, 0), + text: active_snapshot.as_rope().slice(excerpt.range), + }; + + if let Some(buffer_ix) = included_files + .iter() + .position(|file| file.buffer.entity_id() == active_buffer.entity_id()) + { + let file = &mut included_files[buffer_ix]; + file.excerpts.push(related_excerpt); + file.merge_excerpts(); + let last_ix = included_files.len() - 1; + included_files.swap(buffer_ix, last_ix); + } else { + let active_file = RelatedFile { + path: active_project_path, + buffer: active_buffer.downgrade(), + excerpts: vec![related_excerpt], + max_row: active_snapshot.max_point().row, + }; + included_files.push(active_file); + } - let included_files = included_files + let included_files = included_files + .iter() + .map(|related_file| predict_edits_v3::IncludedFile { + path: Arc::from(related_file.path.path.as_std_path()), + max_row: Line(related_file.max_row), + excerpts: related_file + .excerpts .iter() - .map(|(_, snapshot, path, ranges)| { - let ranges = ranges - .iter() - .map(|range| { - let point_range = range.to_point(&snapshot); - Line(point_range.start.row)..Line(point_range.end.row) - }) - .collect::>(); - let excerpts = assemble_excerpts(&snapshot, ranges); - predict_edits_v3::IncludedFile { - path: path.clone(), - max_row: Line(snapshot.max_point().row), - excerpts, - } + .map(|excerpt| predict_edits_v3::Excerpt { + start_line: Line(excerpt.point_range.start.row), + text: excerpt.text.to_string().into(), }) - .collect::>(); - - predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: String::new(), - excerpt_line_range: Line(0)..Line(0), - excerpt_range: 0..0, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(cursor_point.row), - column: cursor_point.column, - }, - included_files, - referenced_declarations: vec![], - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - debug_info: debug_tx.is_some(), - prompt_max_bytes: Some(options.max_prompt_bytes), - prompt_format: options.prompt_format, - // TODO [zeta2] - signatures: vec![], - excerpt_parent: None, - git_info: None, - trigger, - } - } - ContextMode::Syntax(context_options) => { - let Some(context) = EditPredictionContext::gather_context( - cursor_point, - &active_snapshot, - parent_abs_path.as_deref(), - &context_options, - index_state.as_deref(), - ) else { - return Ok((None, None)); - }; - - make_syntax_context_cloud_request( - excerpt_path, - context, - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - None, - debug_tx.is_some(), - &worktree_snapshots, - index_state.as_deref(), - Some(options.max_prompt_bytes), - options.prompt_format, - trigger, - ) - } + .collect(), + }) + .collect::>(); + + let cloud_request = 
predict_edits_v3::PredictEditsRequest { + excerpt_path, + excerpt: String::new(), + excerpt_line_range: Line(0)..Line(0), + excerpt_range: 0..0, + cursor_point: predict_edits_v3::Point { + line: predict_edits_v3::Line(cursor_point.row), + column: cursor_point.column, + }, + included_files, + referenced_declarations: vec![], + events, + can_collect_data, + diagnostic_groups, + diagnostic_groups_truncated, + debug_info: debug_tx.is_some(), + prompt_max_bytes: Some(options.max_prompt_bytes), + prompt_format: options.prompt_format, + // TODO [zeta2] + signatures: vec![], + excerpt_parent: None, + git_info: None, + trigger, }; let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); @@ -1787,18 +1830,17 @@ impl Zeta { } let get_buffer_from_context = |path: &Path| { - included_files - .iter() - .find_map(|(_, buffer, probe_path, ranges)| { - if probe_path.as_ref() == path { - Some((buffer, ranges.as_slice())) - } else { - None - } - }) + if Some(path) == active_file_full_path.as_deref() { + Some(( + &active_snapshot, + std::slice::from_ref(&excerpt_anchor_range), + )) + } else { + None + } }; - let (edited_buffer_snapshot, edits) = match options.prompt_format { + let (_, edits) = match options.prompt_format { PromptFormat::NumLinesUniDiff => { // TODO: Implement parsing of multi-file diffs crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? @@ -1822,24 +1864,13 @@ impl Zeta { } }; - let edited_buffer = included_files - .iter() - .find_map(|(buffer, snapshot, _, _)| { - if snapshot.remote_id() == edited_buffer_snapshot.remote_id() { - Some(buffer.clone()) - } else { - None - } - }) - .context("Failed to find buffer in included_buffers")?; - anyhow::Ok(( Some(( request_id, Some(( inputs, - edited_buffer, - edited_buffer_snapshot.clone(), + active_buffer, + active_snapshot.clone(), edits, received_response_at, )), @@ -2058,61 +2089,78 @@ impl Zeta { pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10); pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3); - // Refresh the related excerpts when the user just beguns editing after - // an idle period, and after they pause editing. - fn refresh_context_if_needed( + pub fn refresh_context_if_needed( &mut self, project: &Entity, buffer: &Entity, cursor_position: language::Anchor, cx: &mut Context, ) { - if !matches!(self.edit_prediction_model, ZetaEditPredictionModel::Zeta2) { + if !self.use_context { return; } - - if !matches!(&self.options().context, ContextMode::Agentic { .. 
}) { - return; - } - let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else { return; }; - let now = Instant::now(); - let was_idle = zeta_project - .refresh_context_timestamp - .map_or(true, |timestamp| { - now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION - }); - zeta_project.refresh_context_timestamp = Some(now); - zeta_project.refresh_context_debounce_task = Some(cx.spawn({ - let buffer = buffer.clone(); - let project = project.clone(); - async move |this, cx| { - if was_idle { - log::debug!("refetching edit prediction context after idle"); - } else { - cx.background_executor() - .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) - .await; - log::debug!("refetching edit prediction context after pause"); - } - this.update(cx, |this, cx| { - let task = this.refresh_context(project.clone(), buffer, cursor_position, cx); + match &mut zeta_project.context { + ZetaProjectContext::Syntax(_entity) => {} + ZetaProjectContext::Lsp(related_excerpt_store) => { + related_excerpt_store.update(cx, |store, cx| { + store.refresh(buffer.clone(), cursor_position, cx); + }); + } + ZetaProjectContext::Agentic { + refresh_context_debounce_task, + refresh_context_timestamp, + .. + } => { + let now = Instant::now(); + let was_idle = refresh_context_timestamp.map_or(true, |timestamp| { + now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION + }); + *refresh_context_timestamp = Some(now); + *refresh_context_debounce_task = Some(cx.spawn({ + let buffer = buffer.clone(); + let project = project.clone(); + async move |this, cx| { + if was_idle { + log::debug!("refetching edit prediction context after idle"); + } else { + cx.background_executor() + .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION) + .await; + log::debug!("refetching edit prediction context after pause"); + } + this.update(cx, |this, cx| { + let task = this.refresh_context_with_agentic_retrieval( + project.clone(), + buffer, + cursor_position, + cx, + ); - if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) { - zeta_project.refresh_context_task = Some(task.log_err()); - }; - }) - .ok() + if let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) + { + if let ZetaProjectContext::Agentic { + refresh_context_task, + .. + } = &mut zeta_project.context + { + *refresh_context_task = Some(task.log_err()); + } + }; + }) + .ok() + } + })); } - })); + } } // Refresh the related excerpts asynchronously. Ensure the task runs to completion, // and avoid spawning more than one concurrent task. 
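In prose: `refresh_context_if_needed` now fans out on the context mode instead of assuming agentic retrieval. A minimal sketch of that policy, with plain stand-ins for the gpui entity types (assumed simplifications; the real code spawns tasks and threads `buffer`, `cursor_position`, and `cx` through):

```rust
use std::time::{Duration, Instant};

const IDLE: Duration = Duration::from_secs(10); // CONTEXT_RETRIEVAL_IDLE_DURATION

// Simplified stand-ins for the three ZetaProjectContext variants.
enum Context {
    Syntax,                                 // background syntax index; no per-edit work
    Lsp,                                    // RelatedExcerptStore refreshes on every call
    Agentic { last_edit: Option<Instant> }, // LLM-driven search, debounced
}

fn refresh_if_needed(ctx: &mut Context, use_context: bool) {
    if !use_context {
        return;
    }
    match ctx {
        Context::Syntax => {}
        Context::Lsp => { /* store.refresh(buffer, cursor_position, cx) */ }
        Context::Agentic { last_edit } => {
            let now = Instant::now();
            // Refresh immediately when editing resumes after an idle period;
            // otherwise wait out the debounce window before refreshing.
            let was_idle = (*last_edit).map_or(true, |t| now - t > IDLE);
            *last_edit = Some(now);
            if was_idle {
                // refresh now
            } else {
                // debounce CONTEXT_RETRIEVAL_DEBOUNCE_DURATION, then refresh
            }
        }
    }
}

fn main() {
    let mut ctx = Context::Agentic { last_edit: None };
    refresh_if_needed(&mut ctx, true); // the first edit counts as "after idle"
}
```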
- pub fn refresh_context( + pub fn refresh_context_with_agentic_retrieval( &mut self, project: Entity, buffer: Entity, @@ -2162,12 +2210,14 @@ impl Zeta { } }; + let retrieval_started_at = Instant::now(); + if let Some(debug_tx) = &debug_tx { debug_tx .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted( ZetaContextRetrievalStartedDebugInfo { - project: project.clone(), - timestamp: Instant::now(), + project_entity_id: project.entity_id(), + timestamp: retrieval_started_at, search_prompt: prompt.clone(), }, )) @@ -2260,19 +2310,8 @@ impl Zeta { queries.extend(input.queries); } - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated( - ZetaSearchQueryDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - search_queries: queries.clone(), - }, - )) - .ok(); - } - log::trace!("Running retrieval search: {queries:#?}"); + let query_generation_finished_at = Instant::now(); let related_excerpts_result = retrieval_search::run_retrieval_searches( queries, @@ -2284,54 +2323,62 @@ impl Zeta { .await; log::trace!("Search queries executed"); - - if let Some(debug_tx) = &debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted( - ZetaContextRetrievalDebugInfo { - project: project.clone(), - timestamp: Instant::now(), - }, - )) - .ok(); - } + let query_execution_finished_at = Instant::now(); this.update(cx, |this, _cx| { let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else { return Ok(()); }; - zeta_project.refresh_context_task.take(); - if let Some(debug_tx) = &this.debug_tx { - debug_tx - .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( - ZetaContextRetrievalDebugInfo { - project, - timestamp: Instant::now(), - }, - )) - .ok(); - } - match related_excerpts_result { - Ok(excerpts) => { - zeta_project.context = Some(excerpts); - Ok(()) + if let ZetaProjectContext::Agentic { + refresh_context_task, + context, + .. + } = &mut zeta_project.context + { + refresh_context_task.take(); + if let Some(debug_tx) = &this.debug_tx { + debug_tx + .unbounded_send(ZetaDebugInfo::ContextRetrievalFinished( + ZetaContextRetrievalFinishedDebugInfo { + project_entity_id: project.entity_id(), + timestamp: Instant::now(), + metadata: vec![ + ( + "query_generation", + format!( + "{:?}", + query_generation_finished_at - retrieval_started_at + ) + .into(), + ), + ( + "search_execution", + format!( + "{:?}", + query_execution_finished_at + - query_generation_finished_at + ) + .into(), + ), + ], + }, + )) + .ok(); + } + match related_excerpts_result { + Ok(excerpts) => { + *context = excerpts; + Ok(()) + } + Err(error) => Err(error), } - Err(error) => Err(error), + } else { + Ok(()) } })? }) } - pub fn set_context( - &mut self, - project: Entity, - context: HashMap, Vec>>, - ) { - if let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) { - zeta_project.context = Some(context); - } - } - fn gather_nearby_diagnostics( cursor_offset: usize, diagnostic_sets: &[(LanguageServerId, DiagnosticSet)], @@ -2378,92 +2425,13 @@ impl Zeta { (results, diagnostic_groups_truncated) } - // TODO: Dedupe with similar code in request_prediction? 
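The recurring shape change across these hunks: the old per-buffer map of anchor ranges (roughly `HashMap<Entity<Buffer>, Vec<Range<Anchor>>>`) becomes an ordered `Vec<RelatedFile>` from `edit_prediction_context2`, with excerpt text and point ranges captured once at retrieval time. A simplified, hypothetical mirror of that data model, with plain types standing in for `ProjectPath`, `WeakEntity<Buffer>`, `Range<Anchor>`, and the `Rope` slice:

```rust
use std::ops::Range;

#[derive(Clone, Copy)]
struct Point {
    row: u32,
    column: u32,
}

struct RelatedExcerpt {
    point_range: Range<Point>, // row/column coordinates, used when building prompts
    text: String,              // excerpt contents, sliced at retrieval time
}

struct RelatedFile {
    path: String, // worktree-relative path in the real type
    excerpts: Vec<RelatedExcerpt>,
    max_row: u32,
}

// Consumers iterate files in insertion order and read excerpt text directly,
// instead of re-resolving anchors against a live snapshot as the old map required.
fn prompt_chunks(files: &[RelatedFile]) -> Vec<(String, u32, String)> {
    files
        .iter()
        .flat_map(|file| {
            file.excerpts
                .iter()
                .map(move |e| (file.path.clone(), e.point_range.start.row, e.text.clone()))
        })
        .collect()
}

fn main() {
    let files = vec![RelatedFile {
        path: "src/user.rs".into(),
        excerpts: vec![RelatedExcerpt {
            point_range: Point { row: 4, column: 0 }..Point { row: 9, column: 0 },
            text: "impl User { /* … */ }".into(),
        }],
        max_row: 40,
    }];
    assert_eq!(prompt_chunks(&files).len(), 1);
}
```

One consequence visible elsewhere in this diff: consumers such as `sweep_ai.rs` can build their `FileChunk`s straight from the stored excerpt text rather than re-reading buffers.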
- pub fn cloud_request_for_zeta_cli( - &mut self, - project: &Entity, - buffer: &Entity, - position: language::Anchor, - cx: &mut Context, - ) -> Task> { - let project_state = self.projects.get(&project.entity_id()); - - let index_state = project_state.and_then(|state| { - state - .syntax_index - .as_ref() - .map(|index| index.read_with(cx, |index, _cx| index.state().clone())) - }); - let options = self.options.clone(); - let snapshot = buffer.read(cx).snapshot(); - let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { - return Task::ready(Err(anyhow!("No file path for excerpt"))); - }; - let worktree_snapshots = project - .read(cx) - .worktrees(cx) - .map(|worktree| worktree.read(cx).snapshot()) - .collect::>(); - - let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| { - let mut path = f.worktree.read(cx).absolutize(&f.path); - if path.pop() { Some(path) } else { None } - }); - - cx.background_spawn(async move { - let index_state = if let Some(index_state) = index_state { - Some(index_state.lock_owned().await) - } else { - None - }; - - let cursor_point = position.to_point(&snapshot); - - let debug_info = true; - EditPredictionContext::gather_context( - cursor_point, - &snapshot, - parent_abs_path.as_deref(), - match &options.context { - ContextMode::Agentic(_) => { - // TODO - panic!("Llm mode not supported in zeta cli yet"); - } - ContextMode::Syntax(edit_prediction_context_options) => { - edit_prediction_context_options - } - }, - index_state.as_deref(), - ) - .context("Failed to select excerpt") - .map(|context| { - make_syntax_context_cloud_request( - excerpt_path.into(), - context, - // TODO pass everything - Vec::new(), - false, - Vec::new(), - false, - None, - debug_info, - &worktree_snapshots, - index_state.as_deref(), - Some(options.max_prompt_bytes), - options.prompt_format, - PredictEditsRequestTrigger::Other, - ) - }) - }) - } - pub fn wait_for_initial_indexing( &mut self, project: &Entity, cx: &mut Context, ) -> Task> { let zeta_project = self.get_or_init_zeta_project(project, cx); - if let Some(syntax_index) = &zeta_project.syntax_index { + if let ZetaProjectContext::Syntax(syntax_index) = &zeta_project.context { syntax_index.read(cx).wait_for_initial_file_indexing(cx) } else { Task::ready(Ok(())) @@ -2555,6 +2523,11 @@ impl Zeta { self.client.telemetry().flush_events().detach(); cx.notify(); } + + fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, Zeta>) { + self.use_context = cx.has_flag::() + && all_language_settings(None, cx).edit_predictions.use_context; + } } pub fn text_from_response(mut res: open_ai::Response) -> Option { @@ -2597,131 +2570,6 @@ pub struct ZedUpdateRequiredError { minimum_version: Version, } -fn make_syntax_context_cloud_request( - excerpt_path: Arc, - context: EditPredictionContext, - events: Vec>, - can_collect_data: bool, - diagnostic_groups: Vec, - diagnostic_groups_truncated: bool, - git_info: Option, - debug_info: bool, - worktrees: &Vec, - index_state: Option<&SyntaxIndexState>, - prompt_max_bytes: Option, - prompt_format: PromptFormat, - trigger: PredictEditsRequestTrigger, -) -> predict_edits_v3::PredictEditsRequest { - let mut signatures = Vec::new(); - let mut declaration_to_signature_index = HashMap::default(); - let mut referenced_declarations = Vec::new(); - - for snippet in context.declarations { - let project_entry_id = snippet.declaration.project_entry_id(); - let Some(path) = worktrees.iter().find_map(|worktree| { - 
worktree.entry_for_id(project_entry_id).map(|entry| { - let mut full_path = RelPathBuf::new(); - full_path.push(worktree.root_name()); - full_path.push(&entry.path); - full_path - }) - }) else { - continue; - }; - - let parent_index = index_state.and_then(|index_state| { - snippet.declaration.parent().and_then(|parent| { - add_signature( - parent, - &mut declaration_to_signature_index, - &mut signatures, - index_state, - ) - }) - }); - - let (text, text_is_truncated) = snippet.declaration.item_text(); - referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { - path: path.as_std_path().into(), - text: text.into(), - range: snippet.declaration.item_line_range(), - text_is_truncated, - signature_range: snippet.declaration.signature_range_in_item_text(), - parent_index, - signature_score: snippet.score(DeclarationStyle::Signature), - declaration_score: snippet.score(DeclarationStyle::Declaration), - score_components: snippet.components, - }); - } - - let excerpt_parent = index_state.and_then(|index_state| { - context - .excerpt - .parent_declarations - .last() - .and_then(|(parent, _)| { - add_signature( - *parent, - &mut declaration_to_signature_index, - &mut signatures, - index_state, - ) - }) - }); - - predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: context.excerpt_text.body, - excerpt_line_range: context.excerpt.line_range, - excerpt_range: context.excerpt.range, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(context.cursor_point.row), - column: context.cursor_point.column, - }, - referenced_declarations, - included_files: vec![], - signatures, - excerpt_parent, - events, - can_collect_data, - diagnostic_groups, - diagnostic_groups_truncated, - git_info, - debug_info, - prompt_max_bytes, - prompt_format, - trigger, - } -} - -fn add_signature( - declaration_id: DeclarationId, - declaration_to_signature_index: &mut HashMap, - signatures: &mut Vec, - index: &SyntaxIndexState, -) -> Option { - if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { - return Some(*signature_index); - } - let Some(parent_declaration) = index.declaration(declaration_id) else { - log::error!("bug: missing parent declaration"); - return None; - }; - let parent_index = parent_declaration.parent().and_then(|parent| { - add_signature(parent, declaration_to_signature_index, signatures, index) - }); - let (text, text_is_truncated) = parent_declaration.signature_text(); - let signature_index = signatures.len(); - signatures.push(Signature { - text: text.into(), - text_is_truncated, - parent_index, - range: parent_declaration.signature_line_range(), - }); - declaration_to_signature_index.insert(declaration_id, signature_index); - Some(signature_index) -} - #[cfg(feature = "eval-support")] pub type EvalCacheKey = (EvalCacheEntryKind, u64); @@ -2917,7 +2765,6 @@ mod tests { use cloud_llm_client::{ EditPredictionRejectReason, EditPredictionRejection, RejectEditPredictionsBody, }; - use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery}; use futures::{ AsyncReadExt, StreamExt, channel::{mpsc, oneshot}, @@ -2929,6 +2776,7 @@ mod tests { }; use indoc::indoc; use language::OffsetRangeExt as _; + use lsp::LanguageServerId; use open_ai::Usage; use pretty_assertions::{assert_eq, assert_matches}; use project::{FakeFs, Project}; @@ -2959,7 +2807,8 @@ mod tests { let buffer1 = project .update(cx, |project, cx| { - let path = project.find_project_path(path!("root/1.txt"), cx).unwrap(); + let path = 
project.find_project_path(path!("/root/1.txt"), cx).unwrap(); + project.set_active_path(Some(path.clone()), cx); project.open_buffer(path, cx) }) .await @@ -2995,58 +2844,38 @@ mod tests { assert_matches!(prediction, BufferEditPrediction::Local { .. }); }); - // Context refresh - let refresh_task = zeta.update(cx, |zeta, cx| { - zeta.refresh_context(project.clone(), buffer1.clone(), position, cx) - }); - let (_request, respond_tx) = requests.predict.next().await.unwrap(); - respond_tx - .send(open_ai::Response { - id: Uuid::new_v4().to_string(), - object: "response".into(), - created: 0, - model: "model".into(), - choices: vec![open_ai::Choice { - index: 0, - message: open_ai::RequestMessage::Assistant { - content: None, - tool_calls: vec![open_ai::ToolCall { - id: "search".into(), - content: open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME - .to_string(), - arguments: serde_json::to_string(&SearchToolInput { - queries: Box::new([SearchToolQuery { - glob: "root/2.txt".to_string(), - syntax_node: vec![], - content: Some(".".into()), - }]), - }) - .unwrap(), - }, - }, - }], - }, - finish_reason: None, - }], - usage: Usage { - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - }, - }) - .unwrap(); - refresh_task.await.unwrap(); - zeta.update(cx, |zeta, _cx| { zeta.reject_current_prediction(EditPredictionRejectReason::Discarded, &project); }); - // Prediction for another file - zeta.update(cx, |zeta, cx| { - zeta.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) + // Prediction for diagnostic in another file + + let diagnostic = lsp::Diagnostic { + range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)), + severity: Some(lsp::DiagnosticSeverity::ERROR), + message: "Sentence is incomplete".to_string(), + ..Default::default() + }; + + project.update(cx, |project, cx| { + project.lsp_store().update(cx, |lsp_store, cx| { + lsp_store + .update_diagnostics( + LanguageServerId(0), + lsp::PublishDiagnosticsParams { + uri: lsp::Uri::from_file_path(path!("/root/2.txt")).unwrap(), + diagnostics: vec![diagnostic], + version: None, + }, + None, + language::DiagnosticSourceKind::Pushed, + &[], + cx, + ) + .unwrap(); + }); }); + let (_request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx .send(model_response(indoc! 
{r#" @@ -4018,7 +3847,6 @@ mod tests { let mut buf = Vec::new(); body.read_to_end(&mut buf).await.ok(); let req = serde_json::from_slice(&buf).unwrap(); - let (res_tx, res_rx) = oneshot::channel(); predict_req_tx.unbounded_send((req, res_tx)).unwrap(); serde_json::to_string(&res_rx.await?).unwrap() diff --git a/crates/zeta2_tools/Cargo.toml b/crates/zeta2_tools/Cargo.toml index 607e24c895d96de1464ff1bfa2a4dfa01c5d9669..8e20224736c658d4d80d678b29d4231ec7e4b2f5 100644 --- a/crates/zeta2_tools/Cargo.toml +++ b/crates/zeta2_tools/Cargo.toml @@ -15,7 +15,6 @@ path = "src/zeta2_tools.rs" anyhow.workspace = true client.workspace = true cloud_llm_client.workspace = true -cloud_zeta2_prompt.workspace = true collections.workspace = true edit_prediction_context.workspace = true editor.workspace = true diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs index 54f1ea2d813f7c00d30b12e341fb3e5ac3f155dc..882846929a62f90f349d40f8f6b6996f83613ec7 100644 --- a/crates/zeta2_tools/src/zeta2_context_view.rs +++ b/crates/zeta2_tools/src/zeta2_context_view.rs @@ -8,26 +8,25 @@ use std::{ use anyhow::Result; use client::{Client, UserStore}; -use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery; use editor::{Editor, PathKey}; use futures::StreamExt as _; use gpui::{ Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle, - Focusable, ParentElement as _, SharedString, Styled as _, Task, TextAlign, Window, actions, - pulsating_between, + Focusable, InteractiveElement as _, IntoElement as _, ParentElement as _, SharedString, + Styled as _, Task, TextAlign, Window, actions, div, pulsating_between, }; use multi_buffer::MultiBuffer; use project::Project; use text::OffsetRangeExt; use ui::{ - ButtonCommon, Clickable, Color, Disableable, FluentBuilder as _, Icon, IconButton, IconName, - IconSize, InteractiveElement, IntoElement, ListHeader, ListItem, StyledTypography, div, h_flex, - v_flex, + ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName, + StyledTypography as _, h_flex, v_flex, }; + use workspace::Item; use zeta::{ - Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo, - ZetaSearchQueryDebugInfo, + Zeta, ZetaContextRetrievalFinishedDebugInfo, ZetaContextRetrievalStartedDebugInfo, + ZetaDebugInfo, }; pub struct Zeta2ContextView { @@ -42,10 +41,8 @@ pub struct Zeta2ContextView { #[derive(Debug)] struct RetrievalRun { editor: Entity, - search_queries: Vec, started_at: Instant, - search_results_generated_at: Option, - search_results_executed_at: Option, + metadata: Vec<(&'static str, SharedString)>, finished_at: Option, } @@ -97,22 +94,12 @@ impl Zeta2ContextView { ) { match event { ZetaDebugInfo::ContextRetrievalStarted(info) => { - if info.project == self.project { + if info.project_entity_id == self.project.entity_id() { self.handle_context_retrieval_started(info, window, cx); } } - ZetaDebugInfo::SearchQueriesGenerated(info) => { - if info.project == self.project { - self.handle_search_queries_generated(info, window, cx); - } - } - ZetaDebugInfo::SearchQueriesExecuted(info) => { - if info.project == self.project { - self.handle_search_queries_executed(info, window, cx); - } - } ZetaDebugInfo::ContextRetrievalFinished(info) => { - if info.project == self.project { + if info.project_entity_id == self.project.entity_id() { self.handle_context_retrieval_finished(info, window, cx); } } @@ -129,7 +116,7 @@ impl Zeta2ContextView { if self .runs .back() - 
diff --git a/crates/zeta2_tools/src/zeta2_context_view.rs b/crates/zeta2_tools/src/zeta2_context_view.rs
index 54f1ea2d813f7c00d30b12e341fb3e5ac3f155dc..882846929a62f90f349d40f8f6b6996f83613ec7 100644
--- a/crates/zeta2_tools/src/zeta2_context_view.rs
+++ b/crates/zeta2_tools/src/zeta2_context_view.rs
@@ -8,26 +8,25 @@ use std::{
 use anyhow::Result;
 use client::{Client, UserStore};
-use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
 use editor::{Editor, PathKey};
 use futures::StreamExt as _;
 use gpui::{
     Animation, AnimationExt, App, AppContext as _, Context, Entity, EventEmitter, FocusHandle,
-    Focusable, ParentElement as _, SharedString, Styled as _, Task, TextAlign, Window, actions,
-    pulsating_between,
+    Focusable, InteractiveElement as _, IntoElement as _, ParentElement as _, SharedString,
+    Styled as _, Task, TextAlign, Window, actions, div, pulsating_between,
 };
 use multi_buffer::MultiBuffer;
 use project::Project;
 use text::OffsetRangeExt;
 use ui::{
-    ButtonCommon, Clickable, Color, Disableable, FluentBuilder as _, Icon, IconButton, IconName,
-    IconSize, InteractiveElement, IntoElement, ListHeader, ListItem, StyledTypography, div, h_flex,
-    v_flex,
+    ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
+    StyledTypography as _, h_flex, v_flex,
 };
+
 use workspace::Item;
 use zeta::{
-    Zeta, ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
-    ZetaSearchQueryDebugInfo,
+    Zeta, ZetaContextRetrievalFinishedDebugInfo, ZetaContextRetrievalStartedDebugInfo,
+    ZetaDebugInfo,
 };
 
 pub struct Zeta2ContextView {
@@ -42,10 +41,8 @@ pub struct Zeta2ContextView {
 #[derive(Debug)]
 struct RetrievalRun {
     editor: Entity<Editor>,
-    search_queries: Vec<SearchToolQuery>,
     started_at: Instant,
-    search_results_generated_at: Option<Instant>,
-    search_results_executed_at: Option<Instant>,
+    metadata: Vec<(&'static str, SharedString)>,
     finished_at: Option<Instant>,
 }
@@ -97,22 +94,12 @@ impl Zeta2ContextView {
     ) {
         match event {
             ZetaDebugInfo::ContextRetrievalStarted(info) => {
-                if info.project == self.project {
+                if info.project_entity_id == self.project.entity_id() {
                     self.handle_context_retrieval_started(info, window, cx);
                 }
             }
-            ZetaDebugInfo::SearchQueriesGenerated(info) => {
-                if info.project == self.project {
-                    self.handle_search_queries_generated(info, window, cx);
-                }
-            }
-            ZetaDebugInfo::SearchQueriesExecuted(info) => {
-                if info.project == self.project {
-                    self.handle_search_queries_executed(info, window, cx);
-                }
-            }
             ZetaDebugInfo::ContextRetrievalFinished(info) => {
-                if info.project == self.project {
+                if info.project_entity_id == self.project.entity_id() {
                     self.handle_context_retrieval_finished(info, window, cx);
                 }
             }
@@ -129,7 +116,7 @@ impl Zeta2ContextView {
         if self
             .runs
             .back()
-            .is_some_and(|run| run.search_results_executed_at.is_none())
+            .is_some_and(|run| run.finished_at.is_none())
         {
             self.runs.pop_back();
         }
@@ -144,11 +131,9 @@ impl Zeta2ContextView {
         self.runs.push_back(RetrievalRun {
             editor,
-            search_queries: Vec::new(),
             started_at: info.timestamp,
-            search_results_generated_at: None,
-            search_results_executed_at: None,
             finished_at: None,
+            metadata: Vec::new(),
         });
 
         cx.notify();
@@ -156,7 +141,7 @@ impl Zeta2ContextView {
     fn handle_context_retrieval_finished(
         &mut self,
-        info: ZetaContextRetrievalDebugInfo,
+        info: ZetaContextRetrievalFinishedDebugInfo,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
@@ -165,67 +150,72 @@ impl Zeta2ContextView {
         };
 
         run.finished_at = Some(info.timestamp);
+        run.metadata = info.metadata;
+
+        let project = self.project.clone();
+        let related_files = self
+            .zeta
+            .read(cx)
+            .context_for_project(&self.project, cx)
+            .to_vec();
+        let editor = run.editor.clone();
         let multibuffer = run.editor.read(cx).buffer().clone();
-        multibuffer.update(cx, |multibuffer, cx| {
-            multibuffer.clear(cx);
-            let context = self.zeta.read(cx).context_for_project(&self.project);
-            let mut paths = Vec::new();
-            for (buffer, ranges) in context {
-                let path = PathKey::for_buffer(&buffer, cx);
-                let snapshot = buffer.read(cx).snapshot();
-                let ranges = ranges
-                    .iter()
-                    .map(|range| range.to_point(&snapshot))
-                    .collect::<Vec<_>>();
-                paths.push((path, buffer, ranges));
-            }
+        if self.current_ix + 2 == self.runs.len() {
+            self.current_ix += 1;
+        }
 
-            for (path, buffer, ranges) in paths {
-                multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx);
+        cx.spawn_in(window, async move |this, cx| {
+            let mut paths = Vec::new();
+            for related_file in related_files {
+                let (buffer, point_ranges): (_, Vec<_>) =
+                    if let Some(buffer) = related_file.buffer.upgrade() {
+                        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
+
+                        (
+                            buffer,
+                            related_file
+                                .excerpts
+                                .iter()
+                                .map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
+                                .collect(),
+                        )
+                    } else {
+                        (
+                            project
+                                .update(cx, |project, cx| {
+                                    project.open_buffer(related_file.path.clone(), cx)
+                                })?
+                                .await?,
+                            related_file
+                                .excerpts
+                                .iter()
+                                .map(|excerpt| excerpt.point_range.clone())
+                                .collect(),
+                        )
+                    };
+                cx.update(|_, cx| {
+                    let path = PathKey::for_buffer(&buffer, cx);
+                    paths.push((path, buffer, point_ranges));
+                })?;
             }
-        });
-
-        run.editor.update(cx, |editor, cx| {
-            editor.move_to_beginning(&Default::default(), window, cx);
-        });
-
-        cx.notify();
-    }
-
-    fn handle_search_queries_generated(
-        &mut self,
-        info: ZetaSearchQueryDebugInfo,
-        _window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
-        let Some(run) = self.runs.back_mut() else {
-            return;
-        };
-        run.search_results_generated_at = Some(info.timestamp);
-        run.search_queries = info.search_queries;
-        cx.notify();
-    }
+            multibuffer.update(cx, |multibuffer, cx| {
+                multibuffer.clear(cx);
 
-    fn handle_search_queries_executed(
-        &mut self,
-        info: ZetaContextRetrievalDebugInfo,
-        _window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
-        if self.current_ix + 2 == self.runs.len() {
-            // Switch to latest when the queries are executed
-            self.current_ix += 1;
-        }
+                for (path, buffer, ranges) in paths {
+                    multibuffer.set_excerpts_for_path(path, buffer, ranges, 0, cx);
+                }
+            })?;
 
-        let Some(run) = self.runs.back_mut() else {
-            return;
-        };
+            editor.update_in(cx, |editor, window, cx| {
+                editor.move_to_beginning(&Default::default(), window, cx);
+            })?;
 
-        run.search_results_executed_at = Some(info.timestamp);
-        cx.notify();
+            this.update(cx, |_, cx| cx.notify())
+        })
+        .detach();
     }
 
@@ -254,8 +244,11 @@ impl Zeta2ContextView {
     }
 
     fn render_informational_footer(&self, cx: &mut Context<'_, Zeta2ContextView>) -> ui::Div {
-        let is_latest = self.runs.len() == self.current_ix + 1;
         let run = &self.runs[self.current_ix];
+        let new_run_started = self
+            .runs
+            .back()
+            .map_or(false, |latest_run| latest_run.finished_at.is_none());
 
         h_flex()
             .p_2()
@@ -264,114 +257,65 @@ impl Zeta2ContextView {
             .text_xs()
             .border_t_1()
             .gap_2()
+            .child(v_flex().h_full().flex_1().child({
+                let t0 = run.started_at;
+                let mut table = ui::Table::<2>::new().width(ui::px(300.)).no_ui_font();
+                for (key, value) in &run.metadata {
+                    table = table.row([key.into_any_element(), value.clone().into_any_element()])
+                }
+                table = table.row([
+                    "Total Time".into_any_element(),
+                    format!("{} ms", (run.finished_at.unwrap_or(t0) - t0).as_millis())
+                        .into_any_element(),
+                ]);
+                table
+            }))
             .child(
-                v_flex().h_full().flex_1().children(
-                    run.search_queries
-                        .iter()
-                        .enumerate()
-                        .flat_map(|(ix, query)| {
-                            std::iter::once(ListHeader::new(query.glob.clone()).into_any_element())
-                                .chain(query.syntax_node.iter().enumerate().map(
-                                    move |(regex_ix, regex)| {
-                                        ListItem::new(ix * 100 + regex_ix)
-                                            .start_slot(
-                                                Icon::new(IconName::MagnifyingGlass)
-                                                    .color(Color::Muted)
-                                                    .size(IconSize::Small),
-                                            )
-                                            .child(regex.clone())
-                                            .into_any_element()
-                                    },
+                v_flex().h_full().text_align(TextAlign::Right).child(
+                    h_flex()
+                        .justify_end()
+                        .child(
+                            IconButton::new("go-back", IconName::ChevronLeft)
+                                .disabled(self.current_ix == 0 || self.runs.len() < 2)
+                                .tooltip(ui::Tooltip::for_action_title(
+                                    "Go to previous run",
+                                    &Zeta2ContextGoBack,
                                 ))
+                                .on_click(cx.listener(|this, _, window, cx| {
+                                    this.handle_go_back(&Zeta2ContextGoBack, window, cx);
+                                })),
+                        )
+                        .child(
+                            div()
+                                .child(format!("{}/{}", self.current_ix + 1, self.runs.len()))
+                                .map(|this| {
+                                    if new_run_started {
+                                        this.with_animation(
+                                            "pulsating-count",
+                                            Animation::new(Duration::from_secs(2))
+                                                .repeat()
+                                                .with_easing(pulsating_between(0.4, 0.8)),
+                                            |label, delta| label.opacity(delta),
                                             )
-                                            .child(regex.clone())
                                             .into_any_element()
-                                }))
-                        }),
+                                    } else {
+                                        this.into_any_element()
+                                    }
+                                }),
+                        )
+                        .child(
+                            IconButton::new("go-forward", IconName::ChevronRight)
+                                .disabled(self.current_ix + 1 == self.runs.len())
+                                .tooltip(ui::Tooltip::for_action_title(
+                                    "Go to next run",
+                                    &Zeta2ContextGoForward,
+                                ))
+                                .on_click(cx.listener(|this, _, window, cx| {
+                                    this.handle_go_forward(&Zeta2ContextGoForward, window, cx);
+                                })),
+                        ),
                 ),
             )
-            .child(
-                v_flex()
-                    .h_full()
-                    .text_align(TextAlign::Right)
-                    .child(
-                        h_flex()
-                            .justify_end()
-                            .child(
-                                IconButton::new("go-back", IconName::ChevronLeft)
-                                    .disabled(self.current_ix == 0 || self.runs.len() < 2)
-                                    .tooltip(ui::Tooltip::for_action_title(
-                                        "Go to previous run",
-                                        &Zeta2ContextGoBack,
-                                    ))
-                                    .on_click(cx.listener(|this, _, window, cx| {
-                                        this.handle_go_back(&Zeta2ContextGoBack, window, cx);
-                                    })),
-                            )
-                            .child(
-                                div()
-                                    .child(format!("{}/{}", self.current_ix + 1, self.runs.len()))
-                                    .map(|this| {
-                                        if self.runs.back().is_some_and(|back| {
-                                            back.search_results_executed_at.is_none()
-                                        }) {
-                                            this.with_animation(
-                                                "pulsating-count",
-                                                Animation::new(Duration::from_secs(2))
-                                                    .repeat()
-                                                    .with_easing(pulsating_between(0.4, 0.8)),
-                                                |label, delta| label.opacity(delta),
-                                            )
-                                            .into_any_element()
-                                        } else {
-                                            this.into_any_element()
-                                        }
-                                    }),
-                            )
-                            .child(
-                                IconButton::new("go-forward", IconName::ChevronRight)
-                                    .disabled(self.current_ix + 1 == self.runs.len())
-                                    .tooltip(ui::Tooltip::for_action_title(
-                                        "Go to next run",
-                                        &Zeta2ContextGoBack,
-                                    ))
-                                    .on_click(cx.listener(|this, _, window, cx| {
-                                        this.handle_go_forward(&Zeta2ContextGoForward, window, cx);
-                                    })),
-                            ),
-                    )
-                    .map(|mut div| {
-                        let pending_message = |div: ui::Div, msg: &'static str| {
-                            if is_latest {
-                                return div.child(msg);
-                            } else {
-                                return div.child("Canceled");
-                            }
-                        };
-
-                        let t0 = run.started_at;
-                        let Some(t1) = run.search_results_generated_at else {
-                            return pending_message(div, "Planning search...");
-                        };
-                        div = div.child(format!("Planned search: {:>5} ms", (t1 - t0).as_millis()));
-
-                        let Some(t2) = run.search_results_executed_at else {
-                            return pending_message(div, "Running search...");
-                        };
-                        div = div.child(format!("Ran search: {:>5} ms", (t2 - t1).as_millis()));
-
-                        div.child(format!(
-                            "Total: {:>5} ms",
-                            (run.finished_at.unwrap_or(t0) - t0).as_millis()
-                        ))
-                    }),
-            )
     }
 }
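The `zeta2_tools.rs` hunks below thread a new `Lsp` context mode through the inspector. Because every `match` on the mode enum is exhaustive, the compiler flags each site that needs a new arm; a toy sketch of that property, with illustrative names only:

```rust
#[derive(Clone, Copy)]
enum ContextMode {
    Llm,
    Syntax,
    Lsp, // newly added variant
}

fn mode_label(mode: ContextMode) -> &'static str {
    // Omitting the `Lsp` arm here would be a compile error, which is what
    // makes this kind of refactor mechanical rather than error-prone.
    match mode {
        ContextMode::Llm => "LLM-based",
        ContextMode::Syntax => "Syntax",
        ContextMode::Lsp => "LSP-based",
    }
}
```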
diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs
index 4e650f2405d63feab010c5c9b73efc75bd576af6..26d68b075153557ab50ed0a231c5d45f0bb9646c 100644
--- a/crates/zeta2_tools/src/zeta2_tools.rs
+++ b/crates/zeta2_tools/src/zeta2_tools.rs
@@ -108,6 +108,7 @@ pub struct Zeta2Inspector {
 pub enum ContextModeState {
     Llm,
+    Lsp,
     Syntax {
         max_retrieved_declarations: Entity,
     },
@@ -222,6 +223,9 @@ impl Zeta2Inspector {
                     ),
                 };
             }
+            ContextMode::Lsp(_) => {
+                self.context_mode = ContextModeState::Lsp;
+            }
         }
         cx.notify();
     }
@@ -302,6 +306,9 @@ impl Zeta2Inspector {
                     ContextModeState::Syntax {
                         max_retrieved_declarations,
                     } => number_input_value(max_retrieved_declarations, cx),
+                    ContextModeState::Lsp => {
+                        zeta::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
+                    }
                 };
 
                 ContextMode::Syntax(EditPredictionContextOptions {
@@ -310,6 +317,7 @@ impl Zeta2Inspector {
                     ..context_options
                 })
             }
+            ContextMode::Lsp(excerpt_options) => ContextMode::Lsp(excerpt_options),
         };
 
         this.set_zeta_options(
@@ -656,6 +664,7 @@ impl Zeta2Inspector {
                 ContextModeState::Syntax {
                     max_retrieved_declarations,
                 } => Some(max_retrieved_declarations.clone()),
+                ContextModeState::Lsp => None,
             })
             .child(self.max_prompt_bytes_input.clone())
             .child(self.render_prompt_format_dropdown(window, cx)),
@@ -679,6 +688,7 @@ impl Zeta2Inspector {
             match &self.context_mode {
                 ContextModeState::Llm => "LLM-based",
                 ContextModeState::Syntax { .. } => "Syntax",
+                ContextModeState::Lsp => "LSP-based",
             },
             ContextMenu::build(window, cx, move |menu, _window, _cx| {
                 menu.item(
@@ -695,6 +705,7 @@ impl Zeta2Inspector {
                             this.zeta.read(cx).options().clone();
                         match current_options.context.clone() {
                             ContextMode::Agentic(_) => {}
+                            ContextMode::Lsp(_) => {}
                             ContextMode::Syntax(context_options) => {
                                 let options = ZetaOptions {
                                     context: ContextMode::Agentic(
@@ -739,6 +750,7 @@ impl Zeta2Inspector {
                                 this.set_zeta_options(options, cx);
                             }
                             ContextMode::Syntax(_) => {}
+                            ContextMode::Lsp(_) => {}
                         }
                     })
                     .ok();
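In the `zeta_cli` changes below, the previously discarded LSP buffer handle is kept in `LoadedContext`, so the buffer stays registered with the language server for as long as the context is used. A sketch of that RAII-guard shape, with a stand-in type rather than the real `OpenLspBufferHandle`:

```rust
// Stand-in for project::lsp_store::OpenLspBufferHandle (illustrative only).
struct LspHandle;

impl Drop for LspHandle {
    fn drop(&mut self) {
        // The real handle unregisters the buffer from the server on drop.
        eprintln!("lsp buffer handle dropped");
    }
}

struct LoadedContext {
    // Held only so the registration is dropped with the context, not before.
    _lsp_open_handle: Option<LspHandle>,
}

fn main() {
    let context = LoadedContext {
        _lsp_open_handle: Some(LspHandle),
    };
    // ... use the buffer while the handle is alive ...
    drop(context); // prints "lsp buffer handle dropped"
}
```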
diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs
index a9c8b5998cdd32a94a71f1894dfbdc40c22abaed..42c0ea185f4401a11c2798f9402a59829f8463df 100644
--- a/crates/zeta_cli/src/main.rs
+++ b/crates/zeta_cli/src/main.rs
@@ -21,15 +21,12 @@ use ::util::paths::PathStyle;
 use anyhow::{Result, anyhow};
 use clap::{Args, Parser, Subcommand, ValueEnum};
 use cloud_llm_client::predict_edits_v3;
-use edit_prediction_context::{
-    EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
-};
+use edit_prediction_context::EditPredictionExcerptOptions;
 use gpui::{Application, AsyncApp, Entity, prelude::*};
 use language::{Bias, Buffer, BufferSnapshot, Point};
 use metrics::delta_chr_f;
-use project::{Project, Worktree};
+use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
 use reqwest_client::ReqwestClient;
-use serde_json::json;
 use std::io::{self};
 use std::time::Duration;
 use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
@@ -97,7 +94,7 @@ struct ContextArgs {
 enum ContextProvider {
     Zeta1,
     #[default]
-    Syntax,
+    Zeta2,
 }
 
 #[derive(Clone, Debug, Args)]
@@ -204,19 +201,12 @@ enum PredictionProvider {
     Sweep,
 }
 
-fn zeta2_args_to_options(args: &Zeta2Args, omit_excerpt_overlaps: bool) -> zeta::ZetaOptions {
+fn zeta2_args_to_options(args: &Zeta2Args) -> zeta::ZetaOptions {
     zeta::ZetaOptions {
-        context: ContextMode::Syntax(EditPredictionContextOptions {
-            max_retrieved_declarations: args.max_retrieved_definitions,
-            use_imports: !args.disable_imports_gathering,
-            excerpt: EditPredictionExcerptOptions {
-                max_bytes: args.max_excerpt_bytes,
-                min_bytes: args.min_excerpt_bytes,
-                target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
-            },
-            score: EditPredictionScoreOptions {
-                omit_excerpt_overlaps,
-            },
+        context: ContextMode::Lsp(EditPredictionExcerptOptions {
+            max_bytes: args.max_excerpt_bytes,
+            min_bytes: args.min_excerpt_bytes,
+            target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
         }),
         max_diagnostic_bytes: args.max_diagnostic_bytes,
         max_prompt_bytes: args.max_prompt_bytes,
@@ -295,6 +285,7 @@ struct LoadedContext {
     worktree: Entity<Worktree>,
     project: Entity<Project>,
     buffer: Entity<Buffer>,
+    lsp_open_handle: Option<OpenLspBufferHandle>,
 }
 
 async fn load_context(
@@ -330,7 +321,7 @@ async fn load_context(
         .await?;
 
     let mut ready_languages = HashSet::default();
-    let (_lsp_open_handle, buffer) = if *use_language_server {
+    let (lsp_open_handle, buffer) = if *use_language_server {
        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
            project.clone(),
            worktree.clone(),
@@ -377,10 +368,11 @@ async fn load_context(
         worktree,
         project,
         buffer,
+        lsp_open_handle,
     })
 }
 
-async fn zeta2_syntax_context(
+async fn zeta2_context(
     args: ContextArgs,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
@@ -390,6 +382,7 @@ async fn zeta2_context(
         project,
         buffer,
         clipped_cursor,
+        lsp_open_handle: _handle,
         ..
     } = load_context(&args, app_state, cx).await?;
 
@@ -406,30 +399,26 @@ async fn zeta2_context(
         zeta::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
     });
     let indexing_done_task = zeta.update(cx, |zeta, cx| {
-        zeta.set_options(zeta2_args_to_options(&args.zeta2_args, true));
+        zeta.set_options(zeta2_args_to_options(&args.zeta2_args));
         zeta.register_buffer(&buffer, &project, cx);
         zeta.wait_for_initial_indexing(&project, cx)
     });
 
     cx.spawn(async move |cx| {
         indexing_done_task.await?;
-        let request = zeta
-            .update(cx, |zeta, cx| {
-                let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
-                zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
-            })?
-            .await?;
-
-        let (prompt_string, section_labels) = cloud_zeta2_prompt::build_prompt(&request)?;
-
-        match args.zeta2_args.output_format {
-            OutputFormat::Prompt => anyhow::Ok(prompt_string),
-            OutputFormat::Request => anyhow::Ok(serde_json::to_string_pretty(&request)?),
-            OutputFormat::Full => anyhow::Ok(serde_json::to_string_pretty(&json!({
-                "request": request,
-                "prompt": prompt_string,
-                "section_labels": section_labels,
-            }))?),
-        }
+        let updates_rx = zeta.update(cx, |zeta, cx| {
+            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
+            zeta.set_use_context(true);
+            zeta.refresh_context_if_needed(&project, &buffer, cursor, cx);
+            zeta.project_context_updates(&project).unwrap()
+        })?;
+
+        updates_rx.recv().await.ok();
+
+        let context = zeta.update(cx, |zeta, cx| {
+            zeta.context_for_project(&project, cx).to_vec()
+        })?;
+
+        anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
    })
 })?
 .await?;
@@ -482,7 +471,6 @@ fn main() {
         None => {
             if args.printenv {
                 ::util::shell_env::print_env();
-                return;
             } else {
                 panic!("Expected a command");
             }
@@ -494,7 +482,7 @@ fn main() {
                         arguments.extension,
                         arguments.limit,
                         arguments.skip,
-                        zeta2_args_to_options(&arguments.zeta2_args, false),
+                        zeta2_args_to_options(&arguments.zeta2_args),
                         cx,
                     )
                     .await;
@@ -507,10 +495,8 @@ fn main() {
                         zeta1_context(context_args, &app_state, cx).await.unwrap();
                     serde_json::to_string_pretty(&context.body).unwrap()
                 }
-                ContextProvider::Syntax => {
-                    zeta2_syntax_context(context_args, &app_state, cx)
-                        .await
-                        .unwrap()
+                ContextProvider::Zeta2 => {
+                    zeta2_context(context_args, &app_state, cx).await.unwrap()
                 }
             };
             println!("{}", result);
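The `predict.rs` changes below collapse the two search-phase timings into a single retrieval duration; the measurement reduces to bracketing each phase with `Instant::now()`. A sketch using only std types:

```rust
use std::time::{Duration, Instant};

struct Timings {
    retrieval_time: Duration,
    prediction_time: Duration,
    total_time: Duration,
}

fn measure(retrieve: impl FnOnce(), predict: impl FnOnce()) -> Timings {
    let start = Instant::now();
    retrieve();
    let retrieval_finished = Instant::now();
    predict();
    let finished = Instant::now();
    Timings {
        retrieval_time: retrieval_finished - start,
        prediction_time: finished - retrieval_finished,
        total_time: finished - start,
    }
}
```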
diff --git a/crates/zeta_cli/src/predict.rs b/crates/zeta_cli/src/predict.rs
index 99fe65cfa3221a1deb18e767e8faa8e1a1fca0ac..9fefc5ce97672796f79482e23acca3599aa1ff44 100644
--- a/crates/zeta_cli/src/predict.rs
+++ b/crates/zeta_cli/src/predict.rs
@@ -136,8 +136,7 @@ pub async fn perform_predict(
             let result = result.clone();
             async move {
                 let mut start_time = None;
-                let mut search_queries_generated_at = None;
-                let mut search_queries_executed_at = None;
+                let mut retrieval_finished_at = None;
                 while let Some(event) = debug_rx.next().await {
                     match event {
                         zeta::ZetaDebugInfo::ContextRetrievalStarted(info) => {
@@ -147,17 +146,17 @@ pub async fn perform_predict(
                                 &info.search_prompt,
                             )?;
                         }
-                        zeta::ZetaDebugInfo::SearchQueriesGenerated(info) => {
-                            search_queries_generated_at = Some(info.timestamp);
-                            fs::write(
-                                example_run_dir.join("search_queries.json"),
-                                serde_json::to_string_pretty(&info.search_queries).unwrap(),
-                            )?;
-                        }
-                        zeta::ZetaDebugInfo::SearchQueriesExecuted(info) => {
-                            search_queries_executed_at = Some(info.timestamp);
+                        zeta::ZetaDebugInfo::ContextRetrievalFinished(info) => {
+                            retrieval_finished_at = Some(info.timestamp);
+                            for (key, value) in &info.metadata {
+                                if *key == "search_queries" {
+                                    fs::write(
+                                        example_run_dir.join("search_queries.json"),
+                                        value.as_bytes(),
+                                    )?;
+                                }
+                            }
                         }
-                        zeta::ZetaDebugInfo::ContextRetrievalFinished(_info) => {}
                         zeta::ZetaDebugInfo::EditPredictionRequested(request) => {
                             let prediction_started_at = Instant::now();
                             start_time.get_or_insert(prediction_started_at);
@@ -200,13 +199,8 @@ pub async fn perform_predict(
                             let mut result = result.lock().unwrap();
                             result.generated_len = response.chars().count();
-
-                            result.planning_search_time =
-                                Some(search_queries_generated_at.unwrap() - start_time.unwrap());
-                            result.running_search_time = Some(
-                                search_queries_executed_at.unwrap()
-                                    - search_queries_generated_at.unwrap(),
-                            );
+                            result.retrieval_time =
+                                retrieval_finished_at.unwrap() - start_time.unwrap();
                             result.prediction_time =
                                 prediction_finished_at - prediction_started_at;
                             result.total_time = prediction_finished_at - start_time.unwrap();
@@ -219,7 +213,12 @@ pub async fn perform_predict(
         });
 
         zeta.update(cx, |zeta, cx| {
-            zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+            zeta.refresh_context_with_agentic_retrieval(
+                project.clone(),
+                cursor_buffer.clone(),
+                cursor_anchor,
+                cx,
+            )
         })?
         .await?;
     }
@@ -321,8 +320,7 @@ pub struct PredictionDetails {
     pub diff: String,
     pub excerpts: Vec,
     pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
-    pub planning_search_time: Option<Duration>,
-    pub running_search_time: Option<Duration>,
+    pub retrieval_time: Duration,
     pub prediction_time: Duration,
     pub total_time: Duration,
     pub run_example_dir: PathBuf,
@@ -336,8 +334,7 @@ impl PredictionDetails {
             diff: Default::default(),
             excerpts: Default::default(),
             excerpts_text: Default::default(),
-            planning_search_time: Default::default(),
-            running_search_time: Default::default(),
+            retrieval_time: Default::default(),
             prediction_time: Default::default(),
             total_time: Default::default(),
             run_example_dir,
@@ -357,28 +354,20 @@ impl PredictionDetails {
     }
 
     pub fn to_markdown(&self) -> String {
-        let inference_time = self.planning_search_time.unwrap_or_default() + self.prediction_time;
-
         format!(
             "## Excerpts\n\n\
             {}\n\n\
             ## Prediction\n\n\
             {}\n\n\
             ## Time\n\n\
-            Planning searches: {}ms\n\
-            Running searches: {}ms\n\
-            Making Prediction: {}ms\n\n\
-            -------------------\n\n\
-            Total: {}ms\n\
-            Inference: {}ms ({:.2}%)\n",
+            Retrieval: {}ms\n\
+            Prediction: {}ms\n\n\
+            Total: {}ms\n",
             self.excerpts_text,
             self.diff,
-            self.planning_search_time.unwrap_or_default().as_millis(),
-            self.running_search_time.unwrap_or_default().as_millis(),
+            self.retrieval_time.as_millis(),
             self.prediction_time.as_millis(),
             self.total_time.as_millis(),
-            inference_time.as_millis(),
-            (inference_time.as_millis() as f64 / self.total_time.as_millis() as f64) * 100.
         )
     }
 }
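The `util.rs` hunks below tolerate exactly one failure, `LanguageNotFound`, when preloading a language, and propagate everything else. That filter hinges on `anyhow::Error::is::<T>()`; a self-contained sketch with a stand-in error type:

```rust
use anyhow::Result;

#[derive(Debug)]
struct LanguageNotFound; // stand-in for language::LanguageNotFound

impl std::fmt::Display for LanguageNotFound {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "language not found")
    }
}

impl std::error::Error for LanguageNotFound {}

fn tolerate_missing_language(result: Result<()>) -> Result<()> {
    if let Err(error) = result {
        // `is::<T>` inspects the concrete type behind an anyhow::Error.
        if !error.is::<LanguageNotFound>() {
            return Err(error);
        }
    }
    Ok(())
}
```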
diff --git a/crates/zeta_cli/src/util.rs b/crates/zeta_cli/src/util.rs
index 699c1c743f67e09ef5ca7211c385114802d4ab32..f4a51d94585f82da008ac832dc62392c365738fd 100644
--- a/crates/zeta_cli/src/util.rs
+++ b/crates/zeta_cli/src/util.rs
@@ -2,7 +2,8 @@ use anyhow::{Result, anyhow};
 use futures::channel::mpsc;
 use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AsyncApp, Entity, Task};
-use language::{Buffer, LanguageId, LanguageServerId, ParseStatus};
+use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
+use project::lsp_store::OpenLspBufferHandle;
 use project::{Project, ProjectPath, Worktree};
 use std::collections::HashSet;
 use std::sync::Arc;
@@ -40,7 +41,7 @@ pub async fn open_buffer_with_language_server(
     path: Arc,
     ready_languages: &mut HashSet<LanguageId>,
     cx: &mut AsyncApp,
-) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
+) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
     let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
 
     let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
@@ -50,6 +51,17 @@ pub async fn open_buffer_with_language_server(
         )
     })?;
 
+    let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
+    let result = language_registry
+        .load_language_for_file_path(path.as_std_path())
+        .await;
+
+    if let Err(error) = result
+        && !error.is::<LanguageNotFound>()
+    {
+        anyhow::bail!(error);
+    }
+
     let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
         buffer.language().map(|language| language.id())
     })?
@@ -57,9 +69,9 @@ pub async fn open_buffer_with_language_server(
         return Err(anyhow!("No language for {}", path.display(path_style)));
     };
 
-    let log_prefix = path.display(path_style);
+    let log_prefix = format!("{} | ", path.display(path_style));
     if !ready_languages.contains(&language_id) {
-        wait_for_lang_server(&project, &buffer, log_prefix.into_owned(), cx).await?;
+        wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
         ready_languages.insert(language_id);
     }
 
@@ -95,7 +107,7 @@ pub fn wait_for_lang_server(
     log_prefix: String,
     cx: &mut AsyncApp,
 ) -> Task<Result<()>> {
-    println!("{}⏵ Waiting for language server", log_prefix);
+    eprintln!("{}⏵ Waiting for language server", log_prefix);
 
     let (mut tx, mut rx) = mpsc::channel(1);
 
@@ -137,7 +149,7 @@ pub fn wait_for_lang_server(
                     ..
                 } = event
                 {
-                    println!("{}⟲ {message}", log_prefix)
+                    eprintln!("{}⟲ {message}", log_prefix)
                 }
             }
         }),
@@ -162,7 +174,7 @@ pub fn wait_for_lang_server(
     cx.spawn(async move |cx| {
         if !has_lang_server {
             // some buffers never have a language server, so this aborts quickly in that case.
-            let timeout = cx.background_executor().timer(Duration::from_secs(5));
+            let timeout = cx.background_executor().timer(Duration::from_secs(500));
             futures::select! {
                 _ = added_rx.next() => {},
                 _ = timeout.fuse() => {
@@ -173,7 +185,7 @@ pub fn wait_for_lang_server(
         let timeout = cx.background_executor().timer(Duration::from_secs(60 * 5));
         let result = futures::select! {
             _ = rx.next() => {
-                println!("{}⚑ Language server idle", log_prefix);
+                eprintln!("{}⚑ Language server idle", log_prefix);
                 anyhow::Ok(())
             },
             _ = timeout.fuse() => {
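The waiting loop above races a readiness signal against a timeout via `futures::select!`. Reduced to its shape (gpui's executor timer swapped for any caller-supplied future):

```rust
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};

async fn wait_ready(
    mut ready: mpsc::Receiver<()>,
    timeout: impl std::future::Future<Output = ()>,
) -> anyhow::Result<()> {
    // `select!` needs fused, Unpin futures; pin the timeout on the heap.
    let mut timeout = Box::pin(timeout.fuse());
    futures::select! {
        _ = ready.next() => Ok(()),
        _ = timeout => anyhow::bail!("timed out waiting for language server"),
    }
}
```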