Detailed changes
@@ -21640,6 +21640,7 @@ dependencies = [
"clock",
"cloud_llm_client",
"cloud_zeta2_prompt",
+ "collections",
"edit_prediction",
"edit_prediction_context",
"feature_flags",
@@ -21653,6 +21654,7 @@ dependencies = [
"pretty_assertions",
"project",
"release_channel",
+ "schemars 1.0.4",
"serde",
"serde_json",
"settings",
@@ -23,7 +23,11 @@ pub struct PredictEditsRequest {
pub cursor_point: Point,
/// Within `signatures`
pub excerpt_parent: Option<usize>,
+ #[serde(skip_serializing_if = "Vec::is_empty", default)]
+ pub included_files: Vec<IncludedFile>,
+ #[serde(skip_serializing_if = "Vec::is_empty", default)]
pub signatures: Vec<Signature>,
+ #[serde(skip_serializing_if = "Vec::is_empty", default)]
pub referenced_declarations: Vec<ReferencedDeclaration>,
pub events: Vec<Event>,
#[serde(default)]
@@ -44,6 +48,19 @@ pub struct PredictEditsRequest {
pub prompt_format: PromptFormat,
}
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct IncludedFile {
+ pub path: Arc<Path>,
+ pub max_row: Line,
+ pub excerpts: Vec<Excerpt>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Excerpt {
+ pub start_line: Line,
+ pub text: Arc<str>,
+}
+
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum PromptFormat {
MarkedExcerpt,
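
For reference, here is a minimal sketch of how the new payload might serialize. It assumes `Line` is the serde-derived row newtype used elsewhere in this PR (so it serializes as a bare number) and that the crate enables serde's `Arc` support, as the derives above imply; all values are illustrative:

```rust
use cloud_llm_client::predict_edits_v3::{Excerpt, IncludedFile, Line};
use std::{path::Path, sync::Arc};

fn example() -> serde_json::Result<String> {
    let file = IncludedFile {
        path: Arc::from(Path::new("src/user.rs")), // hypothetical path
        max_row: Line(42),
        excerpts: vec![Excerpt {
            start_line: Line(2),
            text: "    last_name: String,\n".into(),
        }],
    };
    // Roughly: {"path":"src/user.rs","max_row":42,"excerpts":[...]}.
    // On the request itself, `skip_serializing_if = "Vec::is_empty"` drops
    // empty `included_files`, `signatures`, and `referenced_declarations`.
    serde_json::to_string(&file)
}
```
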
@@ -1,11 +1,14 @@
//! Zeta2 prompt planning and generation code shared with cloud.
use anyhow::{Context as _, Result, anyhow};
-use cloud_llm_client::predict_edits_v3::{self, Line, Point, PromptFormat, ReferencedDeclaration};
+use cloud_llm_client::predict_edits_v3::{
+ self, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration,
+};
use indoc::indoc;
use ordered_float::OrderedFloat;
use rustc_hash::{FxHashMap, FxHashSet};
use serde::Serialize;
+use std::cmp;
use std::fmt::Write;
use std::sync::Arc;
use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
@@ -96,7 +99,177 @@ const UNIFIED_DIFF_REMINDER: &str = indoc! {"
If you're editing multiple files, be sure to reflect filename in the hunk's header.
"};
-pub struct PlannedPrompt<'a> {
+pub fn build_prompt(
+ request: &predict_edits_v3::PredictEditsRequest,
+) -> Result<(String, SectionLabels)> {
+ let mut insertions = match request.prompt_format {
+ PromptFormat::MarkedExcerpt => vec![
+ (
+ Point {
+ line: request.excerpt_line_range.start,
+ column: 0,
+ },
+ EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
+ ),
+ (request.cursor_point, CURSOR_MARKER),
+ (
+ Point {
+ line: request.excerpt_line_range.end,
+ column: 0,
+ },
+ EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
+ ),
+ ],
+ PromptFormat::LabeledSections => vec![(request.cursor_point, CURSOR_MARKER)],
+ PromptFormat::NumLinesUniDiff => {
+ vec![(request.cursor_point, CURSOR_MARKER)]
+ }
+ PromptFormat::OnlySnippets => vec![],
+ };
+
+ let mut prompt = match request.prompt_format {
+ PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
+ PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
+ PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
+ // only intended for use via zeta_cli
+ PromptFormat::OnlySnippets => String::new(),
+ };
+
+ if request.events.is_empty() {
+ prompt.push_str("(No edit history)\n\n");
+ } else {
+ prompt.push_str(
+ "The following are the latest edits made by the user, from earlier to later.\n\n",
+ );
+ push_events(&mut prompt, &request.events);
+ }
+
+ if request.prompt_format == PromptFormat::NumLinesUniDiff {
+ if request.referenced_declarations.is_empty() {
+ prompt.push_str(indoc! {"
+ # File under the cursor:
+
+ The cursor marker <|user_cursor|> indicates the current user cursor position.
+ The file is in current state, edits from edit history have been applied.
+ We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
+
+ "});
+ } else {
+ // Note: This hasn't been trained on yet
+ prompt.push_str(indoc! {"
+ # Code Excerpts:
+
+ The cursor marker <|user_cursor|> indicates the current user cursor position.
+ Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor.
+ Context excerpts are not guaranteed to be relevant, so use your own judgement.
+ Files are in their current state, edits from edit history have been applied.
+ We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
+
+ "});
+ }
+ } else {
+ prompt.push_str("\n## Code\n\n");
+ }
+
+ let mut section_labels = Default::default();
+
+ if !request.referenced_declarations.is_empty() || !request.signatures.is_empty() {
+ let syntax_based_prompt = SyntaxBasedPrompt::populate(request)?;
+ section_labels = syntax_based_prompt.write(&mut insertions, &mut prompt)?;
+ } else {
+ if request.prompt_format == PromptFormat::LabeledSections {
+ anyhow::bail!("PromptFormat::LabeledSections cannot be used with ContextMode::Llm");
+ }
+
+ for related_file in &request.included_files {
+ writeln!(&mut prompt, "`````filename={}", related_file.path.display()).unwrap();
+ write_excerpts(
+ &related_file.excerpts,
+ if related_file.path == request.excerpt_path {
+ &insertions
+ } else {
+ &[]
+ },
+ related_file.max_row,
+ request.prompt_format == PromptFormat::NumLinesUniDiff,
+ &mut prompt,
+ );
+ write!(&mut prompt, "`````\n\n").unwrap();
+ }
+ }
+
+ if request.prompt_format == PromptFormat::NumLinesUniDiff {
+ prompt.push_str(UNIFIED_DIFF_REMINDER);
+ }
+
+ Ok((prompt, section_labels))
+}
+
+pub fn write_excerpts<'a>(
+ excerpts: impl IntoIterator<Item = &'a Excerpt>,
+ sorted_insertions: &[(Point, &str)],
+ file_line_count: Line,
+ include_line_numbers: bool,
+ output: &mut String,
+) {
+ let mut current_row = Line(0);
+ let mut sorted_insertions = sorted_insertions.iter().peekable();
+
+ for excerpt in excerpts {
+ if excerpt.start_line > current_row {
+            writeln!(output, "…").unwrap();
+ }
+ if excerpt.text.is_empty() {
+ return;
+ }
+
+ current_row = excerpt.start_line;
+
+ for mut line in excerpt.text.lines() {
+ if include_line_numbers {
+ write!(output, "{}|", current_row.0 + 1).unwrap();
+ }
+
+ while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
+ match current_row.cmp(&insertion_location.line) {
+ cmp::Ordering::Equal => {
+ let (prefix, suffix) = line.split_at(insertion_location.column as usize);
+ output.push_str(prefix);
+ output.push_str(insertion_marker);
+ line = suffix;
+ sorted_insertions.next();
+ }
+ cmp::Ordering::Less => break,
+ cmp::Ordering::Greater => {
+ sorted_insertions.next();
+ break;
+ }
+ }
+ }
+ output.push_str(line);
+ output.push('\n');
+ current_row.0 += 1;
+ }
+ }
+
+ if current_row < file_line_count {
+        writeln!(output, "…").unwrap();
+ }
+}
+
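
A small usage sketch of `write_excerpts`, assuming the cursor marker renders as `<|user_cursor|>` as the instruction text above suggests (the excerpt contents are made up):

```rust
use cloud_llm_client::predict_edits_v3::{Excerpt, Line, Point};
use cloud_zeta2_prompt::write_excerpts;

fn demo() -> String {
    // Two non-contiguous excerpts from a 30-line file, with one marker
    // insertion at 0-based row 3, column 7.
    let excerpts = [
        Excerpt {
            start_line: Line(2),
            text: "    last_name: String,\n    age: u32,\n".into(),
        },
        Excerpt {
            start_line: Line(10),
            text: "impl User {\n".into(),
        },
    ];
    let mut out = String::new();
    write_excerpts(
        excerpts.iter(),
        &[(Point { line: Line(3), column: 7 }, "<|user_cursor|>")],
        Line(30),
        true, // prepend 1-based line numbers
        &mut out,
    );
    // out:
    //   …
    //   3|    last_name: String,
    //   4|    age<|user_cursor|>: u32,
    //   …
    //   11|impl User {
    //   …
    out
}
```
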
+fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
+ if events.is_empty() {
+ return;
+ };
+
+ writeln!(output, "`````diff").unwrap();
+ for event in events {
+ writeln!(output, "{}", event).unwrap();
+ }
+ writeln!(output, "`````\n").unwrap();
+}
+
+pub struct SyntaxBasedPrompt<'a> {
request: &'a predict_edits_v3::PredictEditsRequest,
/// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
/// `to_prompt_string`.
@@ -120,13 +293,13 @@ pub enum DeclarationStyle {
Declaration,
}
-#[derive(Clone, Debug, Serialize)]
+#[derive(Default, Clone, Debug, Serialize)]
pub struct SectionLabels {
pub excerpt_index: usize,
pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
}
-impl<'a> PlannedPrompt<'a> {
+impl<'a> SyntaxBasedPrompt<'a> {
/// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
///
/// Initializes a priority queue by populating it with each snippet, finding the
@@ -149,7 +322,7 @@ impl<'a> PlannedPrompt<'a> {
///
/// * Does not include file paths / other text when considering max_bytes.
pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
- let mut this = PlannedPrompt {
+ let mut this = Self {
request,
snippets: Vec::new(),
budget_used: request.excerpt.len(),
@@ -354,7 +527,11 @@ impl<'a> PlannedPrompt<'a> {
/// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
/// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
/// chunks.
- pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
+ pub fn write(
+ &'a self,
+ excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
+ prompt: &mut String,
+ ) -> Result<SectionLabels> {
let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
FxHashMap::default();
for snippet in &self.snippets {
@@ -383,95 +560,10 @@ impl<'a> PlannedPrompt<'a> {
excerpt_file_snippets.push(&excerpt_snippet);
file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
- let mut excerpt_file_insertions = match self.request.prompt_format {
- PromptFormat::MarkedExcerpt => vec![
- (
- Point {
- line: self.request.excerpt_line_range.start,
- column: 0,
- },
- EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
- ),
- (self.request.cursor_point, CURSOR_MARKER),
- (
- Point {
- line: self.request.excerpt_line_range.end,
- column: 0,
- },
- EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
- ),
- ],
- PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
- PromptFormat::NumLinesUniDiff => {
- vec![(self.request.cursor_point, CURSOR_MARKER)]
- }
- PromptFormat::OnlySnippets => vec![],
- };
-
- let mut prompt = match self.request.prompt_format {
- PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
- PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
- PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
- // only intended for use via zeta_cli
- PromptFormat::OnlySnippets => String::new(),
- };
-
- if self.request.events.is_empty() {
- prompt.push_str("(No edit history)\n\n");
- } else {
- prompt.push_str(
- "The following are the latest edits made by the user, from earlier to later.\n\n",
- );
- Self::push_events(&mut prompt, &self.request.events);
- }
-
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- if self.request.referenced_declarations.is_empty() {
- prompt.push_str(indoc! {"
- # File under the cursor:
-
- The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history have been applied.
- We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
-
- "});
- } else {
- // Note: This hasn't been trained on yet
- prompt.push_str(indoc! {"
- # Code Excerpts:
-
- The cursor marker <|user_cursor|> indicates the current user cursor position.
- Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor.
- Context excerpts are not guaranteed to be relevant, so use your own judgement.
- Files are in their current state, edits from edit history have been applied.
- We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
-
- "});
- }
- } else {
- prompt.push_str("\n## Code\n\n");
- }
-
let section_labels =
- self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
-
- if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
- prompt.push_str(UNIFIED_DIFF_REMINDER);
- }
-
- Ok((prompt, section_labels))
- }
+ self.push_file_snippets(prompt, excerpt_file_insertions, file_snippets)?;
- fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
- if events.is_empty() {
- return;
- };
-
- writeln!(output, "`````diff").unwrap();
- for event in events {
- writeln!(output, "{}", event).unwrap();
- }
- writeln!(output, "`````\n").unwrap();
+ Ok(section_labels)
}
fn push_file_snippets(
@@ -20,7 +20,6 @@ test-support = [
"text/test-support",
"tree-sitter-rust",
"tree-sitter-python",
- "tree-sitter-rust",
"tree-sitter-typescript",
"settings/test-support",
"util/test-support",
@@ -3833,6 +3833,32 @@ impl BufferSnapshot {
include_extra_context: bool,
theme: Option<&SyntaxTheme>,
) -> Vec<OutlineItem<Anchor>> {
+ self.outline_items_containing_internal(
+ range,
+ include_extra_context,
+ theme,
+ |this, range| this.anchor_after(range.start)..this.anchor_before(range.end),
+ )
+ }
+
+ pub fn outline_items_as_points_containing<T: ToOffset>(
+ &self,
+ range: Range<T>,
+ include_extra_context: bool,
+ theme: Option<&SyntaxTheme>,
+ ) -> Vec<OutlineItem<Point>> {
+ self.outline_items_containing_internal(range, include_extra_context, theme, |_, range| {
+ range
+ })
+ }
+
+ fn outline_items_containing_internal<T: ToOffset, U>(
+ &self,
+ range: Range<T>,
+ include_extra_context: bool,
+ theme: Option<&SyntaxTheme>,
+ range_callback: fn(&Self, Range<Point>) -> Range<U>,
+ ) -> Vec<OutlineItem<U>> {
let range = range.to_offset(self);
let mut matches = self.syntax.matches(range.clone(), &self.text, |grammar| {
grammar.outline_config.as_ref().map(|c| &c.query)
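
The new point-based variant exists so callers that only need row/column coordinates, like zeta2's excerpt merging, can skip anchor creation. A minimal sketch of the intended use (the helper name is hypothetical):

```rust
use language::BufferSnapshot;

// Hypothetical helper: collect the starting row of every outline item.
fn outline_start_rows(snapshot: &BufferSnapshot) -> Vec<u32> {
    snapshot
        .outline_items_as_points_containing(0..snapshot.len(), false, None)
        .into_iter()
        .map(|item| item.range.start.row)
        .collect()
}
```
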
@@ -3905,19 +3931,16 @@ impl BufferSnapshot {
anchor_items.push(OutlineItem {
depth: item_ends_stack.len(),
- range: self.anchor_after(item.range.start)..self.anchor_before(item.range.end),
+ range: range_callback(self, item.range.clone()),
+ source_range_for_text: range_callback(self, item.source_range_for_text.clone()),
text: item.text,
highlight_ranges: item.highlight_ranges,
name_ranges: item.name_ranges,
- body_range: item
- .body_range
- .map(|r| self.anchor_after(r.start)..self.anchor_before(r.end)),
+ body_range: item.body_range.map(|r| range_callback(self, r)),
annotation_range: annotation_row_range.map(|annotation_range| {
- self.anchor_after(Point::new(annotation_range.start, 0))
- ..self.anchor_before(Point::new(
- annotation_range.end,
- self.line_len(annotation_range.end),
- ))
+ let point_range = Point::new(annotation_range.start, 0)
+ ..Point::new(annotation_range.end, self.line_len(annotation_range.end));
+ range_callback(self, point_range)
}),
});
item_ends_stack.push(item.range.end);
@@ -3984,14 +4007,13 @@ impl BufferSnapshot {
if buffer_ranges.is_empty() {
return None;
}
+ let source_range_for_text =
+ buffer_ranges.first().unwrap().0.start..buffer_ranges.last().unwrap().0.end;
let mut text = String::new();
let mut highlight_ranges = Vec::new();
let mut name_ranges = Vec::new();
- let mut chunks = self.chunks(
- buffer_ranges.first().unwrap().0.start..buffer_ranges.last().unwrap().0.end,
- true,
- );
+ let mut chunks = self.chunks(source_range_for_text.clone(), true);
let mut last_buffer_range_end = 0;
for (buffer_range, is_name) in buffer_ranges {
let space_added = !text.is_empty() && buffer_range.start > last_buffer_range_end;
@@ -4037,6 +4059,7 @@ impl BufferSnapshot {
Some(OutlineItem {
depth: 0, // We'll calculate the depth later
range: item_point_range,
+ source_range_for_text: source_range_for_text.to_point(self),
text,
highlight_ranges,
name_ranges,
@@ -16,6 +16,7 @@ pub struct Outline<T> {
pub struct OutlineItem<T> {
pub depth: usize,
pub range: Range<T>,
+ pub source_range_for_text: Range<T>,
pub text: String,
pub highlight_ranges: Vec<(Range<usize>, HighlightStyle)>,
pub name_ranges: Vec<Range<usize>>,
@@ -32,6 +33,8 @@ impl<T: ToPoint> OutlineItem<T> {
OutlineItem {
depth: self.depth,
range: self.range.start.to_point(buffer)..self.range.end.to_point(buffer),
+ source_range_for_text: self.source_range_for_text.start.to_point(buffer)
+ ..self.source_range_for_text.end.to_point(buffer),
text: self.text.clone(),
highlight_ranges: self.highlight_ranges.clone(),
name_ranges: self.name_ranges.clone(),
@@ -205,6 +208,7 @@ mod tests {
OutlineItem {
depth: 0,
range: Point::new(0, 0)..Point::new(5, 0),
+ source_range_for_text: Point::new(0, 0)..Point::new(0, 9),
text: "class Foo".to_string(),
highlight_ranges: vec![],
name_ranges: vec![6..9],
@@ -214,6 +218,7 @@ mod tests {
OutlineItem {
depth: 0,
range: Point::new(2, 0)..Point::new(2, 7),
+ source_range_for_text: Point::new(0, 0)..Point::new(0, 7),
text: "private".to_string(),
highlight_ranges: vec![],
name_ranges: vec![],
@@ -238,6 +243,7 @@ mod tests {
OutlineItem {
depth: 0,
range: Point::new(0, 0)..Point::new(5, 0),
+ source_range_for_text: Point::new(0, 0)..Point::new(0, 10),
text: "fn process".to_string(),
highlight_ranges: vec![],
name_ranges: vec![3..10],
@@ -247,6 +253,7 @@ mod tests {
OutlineItem {
depth: 0,
range: Point::new(7, 0)..Point::new(12, 0),
+ source_range_for_text: Point::new(0, 0)..Point::new(0, 20),
text: "struct DataProcessor".to_string(),
highlight_ranges: vec![],
name_ranges: vec![7..20],
@@ -20,7 +20,7 @@
trait: (_)? @name
"for"? @context
type: (_) @name
- body: (_ "{" @open (_)* "}" @close)) @item
+ body: (_ . "{" @open "}" @close .)) @item
(trait_item
(visibility_modifier)? @context
@@ -31,7 +31,8 @@
(visibility_modifier)? @context
(function_modifiers)? @context
"fn" @context
- name: (_) @name) @item
+ name: (_) @name
+ body: (_ . "{" @open "}" @close .)) @item
(function_signature_item
(visibility_modifier)? @context
@@ -5451,6 +5451,8 @@ impl MultiBufferSnapshot {
Some(OutlineItem {
depth: item.depth,
range: self.anchor_range_in_excerpt(*excerpt_id, item.range)?,
+ source_range_for_text: self
+ .anchor_range_in_excerpt(*excerpt_id, item.source_range_for_text)?,
text: item.text,
highlight_ranges: item.highlight_ranges,
name_ranges: item.name_ranges,
@@ -5484,6 +5486,11 @@ impl MultiBufferSnapshot {
.flat_map(|item| {
Some(OutlineItem {
depth: item.depth,
+ source_range_for_text: Anchor::range_in_buffer(
+ excerpt_id,
+ buffer_id,
+ item.source_range_for_text,
+ ),
range: Anchor::range_in_buffer(excerpt_id, buffer_id, item.range),
text: item.text,
highlight_ranges: item.highlight_ranges,
@@ -2484,6 +2484,7 @@ impl OutlinePanel {
annotation_range: None,
range: search_data.context_range.clone(),
text: search_data.context_text.clone(),
+ source_range_for_text: search_data.context_range.clone(),
highlight_ranges: search_data
.highlights_data
.get()
@@ -18,6 +18,7 @@ chrono.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
+collections.workspace = true
edit_prediction.workspace = true
edit_prediction_context.workspace = true
feature_flags.workspace = true
@@ -29,6 +30,7 @@ language_model.workspace = true
log.workspace = true
project.workspace = true
release_channel.workspace = true
+schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
@@ -43,6 +45,7 @@ cloud_llm_client = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
lsp.workspace = true
indoc.workspace = true
+language = { workspace = true, features = ["test-support"] }
language_model = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
@@ -0,0 +1,192 @@
+use cloud_llm_client::predict_edits_v3::{self, Excerpt};
+use edit_prediction_context::Line;
+use language::{BufferSnapshot, Point};
+use std::ops::Range;
+
+pub fn merge_excerpts(
+ buffer: &BufferSnapshot,
+ sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
+) -> Vec<Excerpt> {
+ let mut output = Vec::new();
+ let mut merged_ranges = Vec::<Range<Line>>::new();
+
+ for line_range in sorted_line_ranges {
+ if let Some(last_line_range) = merged_ranges.last_mut()
+ && line_range.start <= last_line_range.end
+ {
+ last_line_range.end = last_line_range.end.max(line_range.end);
+ continue;
+ }
+ merged_ranges.push(line_range);
+ }
+
+ let outline_items = buffer.outline_items_as_points_containing(0..buffer.len(), false, None);
+ let mut outline_items = outline_items.into_iter().peekable();
+
+ for range in merged_ranges {
+ let point_range = Point::new(range.start.0, 0)..Point::new(range.end.0, 0);
+
+ while let Some(outline_item) = outline_items.peek() {
+ if outline_item.range.start >= point_range.start {
+ break;
+ }
+ if outline_item.range.end > point_range.start {
+ let mut point_range = outline_item.source_range_for_text.clone();
+ point_range.start.column = 0;
+ point_range.end.column = buffer.line_len(point_range.end.row);
+
+ output.push(Excerpt {
+ start_line: Line(point_range.start.row),
+ text: buffer
+ .text_for_range(point_range.clone())
+ .collect::<String>()
+ .into(),
+ })
+ }
+ outline_items.next();
+ }
+
+ output.push(Excerpt {
+ start_line: Line(point_range.start.row),
+ text: buffer
+ .text_for_range(point_range.clone())
+ .collect::<String>()
+ .into(),
+ })
+ }
+
+ output
+}
+
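
A sketch of the merging behavior, with a hypothetical driver function: overlapping or touching ranges collapse into one, and the outline pass prepends the source lines of declarations that start before a range but extend into it, which is what produces the `9|impl User {` header line in the test further down:

```rust
use cloud_llm_client::predict_edits_v3::Excerpt;
use edit_prediction_context::Line;
use language::BufferSnapshot;

// [2, 5) and [4, 8) overlap, so they merge into [2, 8); [20, 22) stays
// separate and becomes its own excerpt.
fn demo(snapshot: &BufferSnapshot) -> Vec<Excerpt> {
    merge_excerpts(
        snapshot,
        [Line(2)..Line(5), Line(4)..Line(8), Line(20)..Line(22)],
    )
}
```
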
+pub fn write_merged_excerpts(
+ buffer: &BufferSnapshot,
+ sorted_line_ranges: impl IntoIterator<Item = Range<Line>>,
+ sorted_insertions: &[(predict_edits_v3::Point, &str)],
+ output: &mut String,
+) {
+ cloud_zeta2_prompt::write_excerpts(
+ merge_excerpts(buffer, sorted_line_ranges).iter(),
+ sorted_insertions,
+ Line(buffer.max_point().row),
+ true,
+ output,
+ );
+}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Arc;
+
+ use super::*;
+ use gpui::{TestAppContext, prelude::*};
+ use indoc::indoc;
+ use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
+ use pretty_assertions::assert_eq;
+ use util::test::marked_text_ranges;
+
+ #[gpui::test]
+ fn test_rust(cx: &mut TestAppContext) {
+ let table = [
+ (
+ indoc! {r#"
+ struct User {
+ first_name: String,
+                    «    last_name: String,
+                        ageˇ: u32,
+                    »    email: String,
+ create_at: Instant,
+ }
+
+ impl User {
+ pub fn first_name(&self) -> String {
+ self.first_name.clone()
+ }
+
+ pub fn full_name(&self) -> String {
+                    «        format!("{} {}", self.first_name, self.last_name)
+                    »    }
+ }
+ "#},
+ indoc! {r#"
+ 1|struct User {
+                    …
+ 3| last_name: String,
+ 4| age<|cursor|>: u32,
+                    …
+ 9|impl User {
+                    …
+ 14| pub fn full_name(&self) -> String {
+ 15| format!("{} {}", self.first_name, self.last_name)
+                    …
+ "#},
+ ),
+ (
+ indoc! {r#"
+ struct User {
+ first_name: String,
+                    «    last_name: String,
+ age: u32,
+ }
+                    »"#
+ },
+ indoc! {r#"
+ 1|struct User {
+                    …
+ 3| last_name: String,
+ 4| age: u32,
+ 5|}
+ "#},
+ ),
+ ];
+
+ for (input, expected_output) in table {
+            let input_without_ranges = input.replace(['«', '»'], "");
+            let input_without_caret = input.replace('ˇ', "");
+            let cursor_offset = input_without_ranges.find('ˇ');
+ let (input, ranges) = marked_text_ranges(&input_without_caret, false);
+ let buffer =
+ cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
+ buffer.read_with(cx, |buffer, _cx| {
+ let insertions = cursor_offset
+ .map(|offset| {
+ let point = buffer.offset_to_point(offset);
+ vec![(
+ predict_edits_v3::Point {
+ line: Line(point.row),
+ column: point.column,
+ },
+ "<|cursor|>",
+ )]
+ })
+ .unwrap_or_default();
+ let ranges: Vec<Range<Line>> = ranges
+ .into_iter()
+ .map(|range| {
+ let point_range = range.to_point(&buffer);
+ Line(point_range.start.row)..Line(point_range.end.row)
+ })
+ .collect();
+
+ let mut output = String::new();
+ write_merged_excerpts(&buffer.snapshot(), ranges, &insertions, &mut output);
+ assert_eq!(output, expected_output);
+ });
+ }
+ }
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(language::tree_sitter_rust::LANGUAGE.into()),
+ )
+ .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
+ .unwrap()
+ }
+}
@@ -116,6 +116,10 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
return;
}
+ self.zeta.update(cx, |zeta, cx| {
+ zeta.refresh_context_if_needed(&self.project, &buffer, cursor_position, cx);
+ });
+
let pending_prediction_id = self.next_pending_prediction_id;
self.next_pending_prediction_id += 1;
let last_request_timestamp = self.last_request_timestamp;
@@ -0,0 +1,586 @@
+use std::{cmp::Reverse, fmt::Write, ops::Range, path::PathBuf, sync::Arc};
+
+use crate::merge_excerpts::write_merged_excerpts;
+use anyhow::{Result, anyhow};
+use collections::HashMap;
+use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
+use futures::{StreamExt, stream::BoxStream};
+use gpui::{App, AsyncApp, Entity, Task};
+use indoc::indoc;
+use language::{Anchor, Bias, Buffer, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _};
+use language_model::{
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
+ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelRequestTool, LanguageModelToolResult, MessageContent, Role,
+};
+use project::{
+ Project, WorktreeSettings,
+ search::{SearchQuery, SearchResult},
+};
+use schemars::JsonSchema;
+use serde::Deserialize;
+use util::paths::{PathMatcher, PathStyle};
+use workspace::item::Settings as _;
+
+const SEARCH_PROMPT: &str = indoc! {r#"
+ ## Task
+
+ You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
+ that will serve as context for predicting the next required edit.
+
+ **Your task:**
+ - Analyze the user's recent edits and current cursor context
+ - Use the `search` tool to find code that may be relevant for predicting the next edit
+ - Focus on finding:
+ - Code patterns that might need similar changes based on the recent edits
+ - Functions, variables, types, and constants referenced in the current cursor context
+ - Related implementations, usages, or dependencies that may require consistent updates
+
+ **Important constraints:**
+ - This conversation has exactly 2 turns
+ - You must make ALL search queries in your first response via the `search` tool
+ - All queries will be executed in parallel and results returned together
+ - In the second turn, you will select the most relevant results via the `select` tool.
+
+ ## User Edits
+
+ {edits}
+
+ ## Current cursor context
+
+ `````filename={current_file_path}
+ {cursor_excerpt}
+ `````
+
+ --
+ Use the `search` tool now
+"#};
+
+const SEARCH_TOOL_NAME: &str = "search";
+
+/// Search for relevant code
+///
+/// For the best results, run multiple queries at once with a single invocation of this tool.
+#[derive(Deserialize, JsonSchema)]
+struct SearchToolInput {
+ /// An array of queries to run for gathering context relevant to the next prediction
+ #[schemars(length(max = 5))]
+ queries: Box<[SearchToolQuery]>,
+}
+
+#[derive(Deserialize, JsonSchema)]
+struct SearchToolQuery {
+ /// A glob pattern to match file paths in the codebase
+ glob: String,
+ /// A regular expression to match content within the files matched by the glob pattern
+ regex: String,
+ /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
+ #[serde(default)]
+ case_sensitive: bool,
+}
+
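
For illustration, the kind of arguments the model is expected to send for the `search` tool, matching the schema above (globs and regexes are made up):

```rust
fn example() -> serde_json::Result<()> {
    let input: SearchToolInput = serde_json::from_value(serde_json::json!({
        "queries": [
            { "glob": "**/*.rs", "regex": "fn full_name", "case_sensitive": true },
            // `case_sensitive` is #[serde(default)], so it can be omitted:
            { "glob": "src/user*.rs", "regex": "struct User" }
        ]
    }))?;
    assert_eq!(input.queries.len(), 2);
    Ok(())
}
```
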
+const RESULTS_MESSAGE: &str = indoc! {"
+ Here are the results of your queries combined and grouped by file:
+
+"};
+
+const SELECT_TOOL_NAME: &str = "select";
+
+const SELECT_PROMPT: &str = indoc! {"
+ Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
+ Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
+ Include up to 200 lines in total.
+"};
+
+/// Select line ranges from search results
+#[derive(Deserialize, JsonSchema)]
+struct SelectToolInput {
+ /// The line ranges to select from search results.
+ ranges: Vec<SelectLineRange>,
+}
+
+/// A specific line range to select from a file
+#[derive(Debug, Deserialize, JsonSchema)]
+struct SelectLineRange {
+ /// The file path containing the lines to select
+ /// Exactly as it appears in the search result codeblocks.
+ path: PathBuf,
+ /// The starting line number (1-based)
+ #[schemars(range(min = 1))]
+ start_line: u32,
+ /// The ending line number (1-based, inclusive)
+ #[schemars(range(min = 1))]
+ end_line: u32,
+}
+
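
And a matching sketch of a `select` call; line numbers are 1-based and inclusive, and `path` must reproduce a filename from the result codeblocks verbatim (this one is made up):

```rust
fn example() -> serde_json::Result<()> {
    let input: SelectToolInput = serde_json::from_value(serde_json::json!({
        "ranges": [
            { "path": "src/user.rs", "start_line": 3, "end_line": 8 }
        ]
    }))?;
    assert_eq!(input.ranges.len(), 1);
    Ok(())
}
```
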
+#[derive(Debug, Clone, PartialEq)]
+pub struct LlmContextOptions {
+ pub excerpt: EditPredictionExcerptOptions,
+}
+
+pub fn find_related_excerpts<'a>(
+ buffer: Entity<language::Buffer>,
+ cursor_position: Anchor,
+ project: &Entity<Project>,
+ events: impl Iterator<Item = &'a crate::Event>,
+ options: &LlmContextOptions,
+ cx: &App,
+) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
+ let language_model_registry = LanguageModelRegistry::global(cx);
+ let Some(model) = language_model_registry
+ .read(cx)
+ .available_models(cx)
+ .find(|model| {
+ model.provider_id() == language_model::ANTHROPIC_PROVIDER_ID
+ && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
+ })
+ else {
+ return Task::ready(Err(anyhow!("could not find claude model")));
+ };
+
+ let mut edits_string = String::new();
+
+ for event in events {
+ if let Some(event) = event.to_request_event(cx) {
+ writeln!(&mut edits_string, "{event}").ok();
+ }
+ }
+
+ if edits_string.is_empty() {
+ edits_string.push_str("(No user edits yet)");
+ }
+
+ // TODO [zeta2] include breadcrumbs?
+ let snapshot = buffer.read(cx).snapshot();
+ let cursor_point = cursor_position.to_point(&snapshot);
+ let Some(cursor_excerpt) =
+ EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
+ else {
+ return Task::ready(Ok(HashMap::default()));
+ };
+
+ let current_file_path = snapshot
+ .file()
+ .map(|f| f.full_path(cx).display().to_string())
+ .unwrap_or_else(|| "untitled".to_string());
+
+ let prompt = SEARCH_PROMPT
+ .replace("{edits}", &edits_string)
+        .replace("{current_file_path}", &current_file_path)
+ .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
+
+ let path_style = project.read(cx).path_style(cx);
+
+ let exclude_matcher = {
+ let global_settings = WorktreeSettings::get_global(cx);
+ let exclude_patterns = global_settings
+ .file_scan_exclusions
+ .sources()
+ .iter()
+ .chain(global_settings.private_files.sources().iter());
+
+ match PathMatcher::new(exclude_patterns, path_style) {
+ Ok(matcher) => matcher,
+ Err(err) => {
+ return Task::ready(Err(anyhow!(err)));
+ }
+ }
+ };
+
+ let project = project.clone();
+ cx.spawn(async move |cx| {
+ let initial_prompt_message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![prompt.into()],
+ cache: false,
+ };
+
+ let mut search_stream = request_tool_call::<SearchToolInput>(
+ vec![initial_prompt_message.clone()],
+ SEARCH_TOOL_NAME,
+ &model,
+ cx,
+ )
+ .await?;
+
+ let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
+ select_request_messages.push(initial_prompt_message);
+ let mut search_calls = Vec::new();
+
+ while let Some(event) = search_stream.next().await {
+ match event? {
+ LanguageModelCompletionEvent::ToolUse(tool_use) => {
+ if !tool_use.is_input_complete {
+ continue;
+ }
+
+ if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
+ search_calls.push((select_request_messages.len(), tool_use));
+ } else {
+ log::warn!(
+ "context gathering model tried to use unknown tool: {}",
+ tool_use.name
+ );
+ }
+ }
+ LanguageModelCompletionEvent::Text(txt) => {
+ if let Some(LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content,
+ ..
+ }) = select_request_messages.last_mut()
+ {
+ if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
+ existing_text.push_str(&txt);
+ } else {
+ content.push(MessageContent::Text(txt));
+ }
+ } else {
+ select_request_messages.push(LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::Text(txt)],
+ cache: false,
+ });
+ }
+ }
+ LanguageModelCompletionEvent::Thinking { text, signature } => {
+ if let Some(LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content,
+ ..
+ }) = select_request_messages.last_mut()
+ {
+ if let Some(MessageContent::Thinking {
+ text: existing_text,
+ signature: existing_signature,
+ }) = content.last_mut()
+ {
+ existing_text.push_str(&text);
+ *existing_signature = signature;
+ } else {
+ content.push(MessageContent::Thinking { text, signature });
+ }
+ } else {
+ select_request_messages.push(LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::Thinking { text, signature }],
+ cache: false,
+ });
+ }
+ }
+ LanguageModelCompletionEvent::RedactedThinking { data } => {
+ if let Some(LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content,
+ ..
+ }) = select_request_messages.last_mut()
+ {
+ if let Some(MessageContent::RedactedThinking(existing_data)) =
+ content.last_mut()
+ {
+ existing_data.push_str(&data);
+ } else {
+ content.push(MessageContent::RedactedThinking(data));
+ }
+ } else {
+ select_request_messages.push(LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::RedactedThinking(data)],
+ cache: false,
+ });
+ }
+ }
+ ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
+ log::error!("{ev:?}");
+ }
+ ev => {
+ log::trace!("context search event: {ev:?}")
+ }
+ }
+ }
+
+ struct ResultBuffer {
+ buffer: Entity<Buffer>,
+ snapshot: TextBufferSnapshot,
+ }
+
+ let mut result_buffers_by_path = HashMap::default();
+
+ for (index, tool_use) in search_calls.into_iter().rev() {
+ let call = serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
+
+ let mut excerpts_by_buffer = HashMap::default();
+
+ for query in call.queries {
+ // TODO [zeta2] parallelize?
+
+ run_query(
+ query,
+ &mut excerpts_by_buffer,
+ path_style,
+ exclude_matcher.clone(),
+ &project,
+ cx,
+ )
+ .await?;
+ }
+
+ if excerpts_by_buffer.is_empty() {
+ continue;
+ }
+
+ let mut merged_result = RESULTS_MESSAGE.to_string();
+
+ for (buffer_entity, mut excerpts_for_buffer) in excerpts_by_buffer {
+ excerpts_for_buffer.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
+
+ buffer_entity
+ .clone()
+ .read_with(cx, |buffer, cx| {
+ let Some(file) = buffer.file() else {
+ return;
+ };
+
+ let path = file.full_path(cx);
+
+ writeln!(&mut merged_result, "`````filename={}", path.display()).unwrap();
+
+ let snapshot = buffer.snapshot();
+
+ write_merged_excerpts(
+ &snapshot,
+ excerpts_for_buffer,
+ &[],
+ &mut merged_result,
+ );
+
+ merged_result.push_str("`````\n\n");
+
+ result_buffers_by_path.insert(
+ path,
+ ResultBuffer {
+ buffer: buffer_entity,
+ snapshot: snapshot.text,
+ },
+ );
+ })
+ .ok();
+ }
+
+ let tool_result = LanguageModelToolResult {
+ tool_use_id: tool_use.id.clone(),
+ tool_name: SEARCH_TOOL_NAME.into(),
+ is_error: false,
+ content: merged_result.into(),
+ output: None,
+ };
+
+ // Almost always appends at the end, but in theory, the model could return some text after the tool call
+ // or perform parallel tool calls, so we splice at the message index for correctness.
+ select_request_messages.splice(
+ index..index,
+ [
+ LanguageModelRequestMessage {
+ role: Role::Assistant,
+ content: vec![MessageContent::ToolUse(tool_use)],
+ cache: false,
+ },
+ LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![MessageContent::ToolResult(tool_result)],
+ cache: false,
+ },
+ ],
+ );
+ }
+
+ if result_buffers_by_path.is_empty() {
+ log::trace!("context gathering queries produced no results");
+ return anyhow::Ok(HashMap::default());
+ }
+
+ select_request_messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![SELECT_PROMPT.into()],
+ cache: false,
+ });
+
+ let mut select_stream = request_tool_call::<SelectToolInput>(
+ select_request_messages,
+ SELECT_TOOL_NAME,
+ &model,
+ cx,
+ )
+ .await?;
+ let mut selected_ranges = Vec::new();
+
+ while let Some(event) = select_stream.next().await {
+ match event? {
+ LanguageModelCompletionEvent::ToolUse(tool_use) => {
+ if !tool_use.is_input_complete {
+ continue;
+ }
+
+ if tool_use.name.as_ref() == SELECT_TOOL_NAME {
+ let call =
+ serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
+ selected_ranges.extend(call.ranges);
+ } else {
+ log::warn!(
+ "context gathering model tried to use unknown tool: {}",
+ tool_use.name
+ );
+ }
+ }
+ ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
+ log::error!("{ev:?}");
+ }
+ ev => {
+ log::trace!("context select event: {ev:?}")
+ }
+ }
+ }
+
+ if selected_ranges.is_empty() {
+ log::trace!("context gathering selected no ranges")
+ }
+
+ let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
+
+ for selected_range in selected_ranges {
+ if let Some(ResultBuffer { buffer, snapshot }) =
+ result_buffers_by_path.get(&selected_range.path)
+ {
+ let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
+ let end_point =
+ snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
+ let range = snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
+
+ related_excerpts_by_buffer
+ .entry(buffer.clone())
+ .or_default()
+ .push(range);
+ } else {
+ log::warn!(
+ "selected path that wasn't included in search results: {}",
+ selected_range.path.display()
+ );
+ }
+ }
+
+ for (buffer, ranges) in &mut related_excerpts_by_buffer {
+ buffer.read_with(cx, |buffer, _cx| {
+ ranges.sort_unstable_by(|a, b| {
+ a.start
+ .cmp(&b.start, buffer)
+ .then(b.end.cmp(&a.end, buffer))
+ });
+ })?;
+ }
+
+ anyhow::Ok(related_excerpts_by_buffer)
+ })
+}
+
+async fn request_tool_call<T: JsonSchema>(
+ messages: Vec<LanguageModelRequestMessage>,
+ tool_name: &'static str,
+ model: &Arc<dyn LanguageModel>,
+ cx: &mut AsyncApp,
+) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
+{
+ let schema = schemars::schema_for!(T);
+
+ let request = LanguageModelRequest {
+ messages,
+ tools: vec![LanguageModelRequestTool {
+ name: tool_name.into(),
+ description: schema
+ .get("description")
+ .and_then(|description| description.as_str())
+ .unwrap()
+ .to_string(),
+ input_schema: serde_json::to_value(schema).unwrap(),
+ }],
+ ..Default::default()
+ };
+
+ Ok(model.stream_completion(request, cx).await?)
+}
+
+const MIN_EXCERPT_LEN: usize = 16;
+const MAX_EXCERPT_LEN: usize = 768;
+const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
+
+async fn run_query(
+ args: SearchToolQuery,
+ excerpts_by_buffer: &mut HashMap<Entity<Buffer>, Vec<Range<Line>>>,
+ path_style: PathStyle,
+ exclude_matcher: PathMatcher,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+) -> Result<()> {
+ let include_matcher = PathMatcher::new(vec![args.glob], path_style)?;
+
+ let query = SearchQuery::regex(
+ &args.regex,
+ false,
+ args.case_sensitive,
+ false,
+ true,
+ include_matcher,
+ exclude_matcher,
+ true,
+ None,
+ )?;
+
+ let results = project.update(cx, |project, cx| project.search(query, cx))?;
+ futures::pin_mut!(results);
+
+ let mut total_bytes = 0;
+
+ while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
+ if ranges.is_empty() {
+ continue;
+ }
+
+ let excerpts_for_buffer = excerpts_by_buffer
+ .entry(buffer.clone())
+ .or_insert_with(|| Vec::with_capacity(ranges.len()));
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+ for range in ranges {
+ let offset_range = range.to_offset(&snapshot);
+ let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
+
+ if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
+ break;
+ }
+
+ let excerpt = EditPredictionExcerpt::select_from_buffer(
+ query_point,
+ &snapshot,
+ &EditPredictionExcerptOptions {
+ max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
+ min_bytes: MIN_EXCERPT_LEN,
+ target_before_cursor_over_total_bytes: 0.5,
+ },
+ None,
+ );
+
+ if let Some(excerpt) = excerpt {
+ total_bytes += excerpt.range.len();
+ if !excerpt.line_range.is_empty() {
+ excerpts_for_buffer.push(excerpt.line_range);
+ }
+ }
+ }
+
+ if excerpts_for_buffer.is_empty() {
+ excerpts_by_buffer.remove(&buffer);
+ }
+ }
+
+ anyhow::Ok(())
+}
@@ -6,10 +6,12 @@ use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
ZED_VERSION_HEADER_NAME,
};
-use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
+use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
+use collections::HashMap;
use edit_prediction_context::{
DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
- EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
+ EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionScoreOptions, Line,
+ SyntaxIndex, SyntaxIndexState,
};
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::AsyncReadExt as _;
@@ -19,25 +21,32 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
-use language::BufferSnapshot;
-use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
+use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
+use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
-use std::collections::{HashMap, VecDeque, hash_map};
+use std::collections::{VecDeque, hash_map};
+use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
+use util::ResultExt as _;
use util::rel_path::RelPathBuf;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+mod merge_excerpts;
mod prediction;
mod provider;
+mod related_excerpts;
+use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
+pub use crate::related_excerpts::LlmContextOptions;
+use crate::related_excerpts::find_related_excerpts;
pub use provider::ZetaEditPredictionProvider;
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
@@ -45,19 +54,28 @@ const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;
-pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
- use_imports: true,
- max_retrieved_declarations: 0,
- excerpt: EditPredictionExcerptOptions {
- max_bytes: 512,
- min_bytes: 128,
- target_before_cursor_over_total_bytes: 0.5,
- },
- score: EditPredictionScoreOptions {
- omit_excerpt_overlaps: true,
- },
+pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
+ max_bytes: 512,
+ min_bytes: 128,
+ target_before_cursor_over_total_bytes: 0.5,
+};
+
+pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
+
+pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
+ excerpt: DEFAULT_EXCERPT_OPTIONS,
};
+pub const DEFAULT_SYNTAX_CONTEXT_OPTIONS: EditPredictionContextOptions =
+ EditPredictionContextOptions {
+ use_imports: true,
+ max_retrieved_declarations: 0,
+ excerpt: DEFAULT_EXCERPT_OPTIONS,
+ score: EditPredictionScoreOptions {
+ omit_excerpt_overlaps: true,
+ },
+ };
+
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
context: DEFAULT_CONTEXT_OPTIONS,
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
@@ -94,13 +112,28 @@ pub struct Zeta {
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
- pub context: EditPredictionContextOptions,
+ pub context: ContextMode,
pub max_prompt_bytes: usize,
pub max_diagnostic_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
pub file_indexing_parallelism: usize,
}
+#[derive(Debug, Clone, PartialEq)]
+pub enum ContextMode {
+ Llm(LlmContextOptions),
+ Syntax(EditPredictionContextOptions),
+}
+
+impl ContextMode {
+ pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
+ match self {
+ ContextMode::Llm(options) => &options.excerpt,
+ ContextMode::Syntax(options) => &options.excerpt,
+ }
+ }
+}
+
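
A sketch of switching modes, assuming the `DEFAULT_*` consts defined earlier in this file:

```rust
// Hypothetical: opt back into syntax-based retrieval while keeping the other
// defaults. Both variants expose the shared excerpt options via `excerpt()`.
fn syntax_mode_options() -> ZetaOptions {
    ZetaOptions {
        context: ContextMode::Syntax(DEFAULT_SYNTAX_CONTEXT_OPTIONS),
        ..DEFAULT_OPTIONS
    }
}
```
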
pub struct PredictionDebugInfo {
pub request: predict_edits_v3::PredictEditsRequest,
pub retrieval_time: TimeDelta,
@@ -117,6 +150,10 @@ struct ZetaProject {
events: VecDeque<Event>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
+ context: Option<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>,
+ refresh_context_task: Option<Task<Option<()>>>,
+ refresh_context_debounce_task: Option<Task<Option<()>>>,
+ refresh_context_timestamp: Option<Instant>,
}
#[derive(Debug, Clone)]
@@ -183,6 +220,44 @@ pub enum Event {
},
}
+impl Event {
+ pub fn to_request_event(&self, cx: &App) -> Option<predict_edits_v3::Event> {
+ match self {
+ Event::BufferChange {
+ old_snapshot,
+ new_snapshot,
+ ..
+ } => {
+ let path = new_snapshot.file().map(|f| f.full_path(cx));
+
+ let old_path = old_snapshot.file().and_then(|f| {
+ let old_path = f.full_path(cx);
+ if Some(&old_path) != path.as_ref() {
+ Some(old_path)
+ } else {
+ None
+ }
+ });
+
+ // TODO [zeta2] move to bg?
+ let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
+
+ if path == old_path && diff.is_empty() {
+ None
+ } else {
+ Some(predict_edits_v3::Event::BufferChange {
+ old_path,
+ path,
+ diff,
+ //todo: Actually detect if this edit was predicted or not
+ predicted: false,
+ })
+ }
+ }
+ }
+ }
+}
+
impl Zeta {
pub fn try_global(cx: &App) -> Option<Entity<Self>> {
cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
@@ -206,7 +281,7 @@ impl Zeta {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
- projects: HashMap::new(),
+ projects: HashMap::default(),
client,
user_store,
options: DEFAULT_OPTIONS,
@@ -248,6 +323,14 @@ impl Zeta {
}
}
+ pub fn history_for_project(&self, project: &Entity<Project>) -> impl Iterator<Item = &Event> {
+ static EMPTY_EVENTS: VecDeque<Event> = VecDeque::new();
+ self.projects
+ .get(&project.entity_id())
+ .map_or(&EMPTY_EVENTS, |project| &project.events)
+ .iter()
+ }
+
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
self.user_store.read(cx).edit_prediction_usage()
}
@@ -278,8 +361,12 @@ impl Zeta {
SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
}),
events: VecDeque::new(),
- registered_buffers: HashMap::new(),
+ registered_buffers: HashMap::default(),
current_prediction: None,
+ context: None,
+ refresh_context_task: None,
+ refresh_context_debounce_task: None,
+ refresh_context_timestamp: None,
})
}
@@ -507,7 +594,10 @@ impl Zeta {
});
let options = self.options.clone();
let snapshot = buffer.read(cx).snapshot();
- let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
+ let Some(excerpt_path) = snapshot
+ .file()
+ .map(|path| -> Arc<Path> { path.full_path(cx).into() })
+ else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
};
let client = self.client.clone();
@@ -525,40 +615,7 @@ impl Zeta {
state
.events
.iter()
- .filter_map(|event| match event {
- Event::BufferChange {
- old_snapshot,
- new_snapshot,
- ..
- } => {
- let path = new_snapshot.file().map(|f| f.full_path(cx));
-
- let old_path = old_snapshot.file().and_then(|f| {
- let old_path = f.full_path(cx);
- if Some(&old_path) != path.as_ref() {
- Some(old_path)
- } else {
- None
- }
- });
-
- // TODO [zeta2] move to bg?
- let diff =
- language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
-
- if path == old_path && diff.is_empty() {
- None
- } else {
- Some(predict_edits_v3::Event::BufferChange {
- old_path,
- path,
- diff,
- //todo: Actually detect if this edit was predicted or not
- predicted: false,
- })
- }
- }
- })
+ .filter_map(|event| event.to_request_event(cx))
.collect::<Vec<_>>()
})
.unwrap_or_default();
@@ -573,6 +630,20 @@ impl Zeta {
// TODO data collection
let can_collect_data = cx.is_staff();
+ let mut included_files = project_state
+ .and_then(|project_state| project_state.context.as_ref())
+ .unwrap_or(&HashMap::default())
+ .iter()
+ .filter_map(|(buffer, ranges)| {
+ let buffer = buffer.read(cx);
+ Some((
+ buffer.snapshot(),
+ buffer.file()?.full_path(cx).into(),
+ ranges.clone(),
+ ))
+ })
+ .collect::<Vec<_>>();
+
let request_task = cx.background_spawn({
let snapshot = snapshot.clone();
let buffer = buffer.clone();
@@ -588,18 +659,6 @@ impl Zeta {
let before_retrieval = chrono::Utc::now();
- let Some(context) = EditPredictionContext::gather_context(
- cursor_point,
- &snapshot,
- parent_abs_path.as_deref(),
- &options.context,
- index_state.as_deref(),
- ) else {
- return Ok((None, None));
- };
-
- let retrieval_time = chrono::Utc::now() - before_retrieval;
-
let (diagnostic_groups, diagnostic_groups_truncated) =
Self::gather_nearby_diagnostics(
cursor_offset,
@@ -608,26 +667,127 @@ impl Zeta {
options.max_diagnostic_bytes,
);
- let request = make_cloud_request(
- excerpt_path,
- context,
- events,
- can_collect_data,
- diagnostic_groups,
- diagnostic_groups_truncated,
- None,
- debug_tx.is_some(),
- &worktree_snapshots,
- index_state.as_deref(),
- Some(options.max_prompt_bytes),
- options.prompt_format,
- );
+ let request = match options.context {
+ ContextMode::Llm(context_options) => {
+ let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &snapshot,
+ &context_options.excerpt,
+ index_state.as_deref(),
+ ) else {
+ return Ok((None, None));
+ };
+
+ let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
+ ..snapshot.anchor_before(excerpt.range.end);
+
+ if let Some(buffer_ix) = included_files
+ .iter()
+ .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
+ {
+ let (buffer, _, ranges) = &mut included_files[buffer_ix];
+ let range_ix = ranges
+ .binary_search_by(|probe| {
+ probe
+ .start
+ .cmp(&excerpt_anchor_range.start, buffer)
+ .then(excerpt_anchor_range.end.cmp(&probe.end, buffer))
+ })
+ .unwrap_or_else(|ix| ix);
+
+ ranges.insert(range_ix, excerpt_anchor_range);
+ let last_ix = included_files.len() - 1;
+ included_files.swap(buffer_ix, last_ix);
+ } else {
+ included_files.push((
+ snapshot,
+ excerpt_path.clone(),
+ vec![excerpt_anchor_range],
+ ));
+ }
+
+ let included_files = included_files
+ .into_iter()
+ .map(|(buffer, path, ranges)| {
+ let excerpts = merge_excerpts(
+ &buffer,
+ ranges.iter().map(|range| {
+ let point_range = range.to_point(&buffer);
+ Line(point_range.start.row)..Line(point_range.end.row)
+ }),
+ );
+ predict_edits_v3::IncludedFile {
+ path,
+ max_row: Line(buffer.max_point().row),
+ excerpts,
+ }
+ })
+ .collect::<Vec<_>>();
+
+ predict_edits_v3::PredictEditsRequest {
+ excerpt_path,
+ excerpt: String::new(),
+ excerpt_line_range: Line(0)..Line(0),
+ excerpt_range: 0..0,
+ cursor_point: predict_edits_v3::Point {
+ line: predict_edits_v3::Line(cursor_point.row),
+ column: cursor_point.column,
+ },
+ included_files,
+ referenced_declarations: vec![],
+ events,
+ can_collect_data,
+ diagnostic_groups,
+ diagnostic_groups_truncated,
+ debug_info: debug_tx.is_some(),
+ prompt_max_bytes: Some(options.max_prompt_bytes),
+ prompt_format: options.prompt_format,
+ // TODO [zeta2]
+ signatures: vec![],
+ excerpt_parent: None,
+ git_info: None,
+ }
+ }
+ ContextMode::Syntax(context_options) => {
+ let Some(context) = EditPredictionContext::gather_context(
+ cursor_point,
+ &snapshot,
+ parent_abs_path.as_deref(),
+ &context_options,
+ index_state.as_deref(),
+ ) else {
+ return Ok((None, None));
+ };
+
+ make_syntax_context_cloud_request(
+ excerpt_path,
+ context,
+ events,
+ can_collect_data,
+ diagnostic_groups,
+ diagnostic_groups_truncated,
+ None,
+ debug_tx.is_some(),
+ &worktree_snapshots,
+ index_state.as_deref(),
+ Some(options.max_prompt_bytes),
+ options.prompt_format,
+ )
+ }
+ };
+
+ let retrieval_time = chrono::Utc::now() - before_retrieval;
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
let (response_tx, response_rx) = oneshot::channel();
- let local_prompt = PlannedPrompt::populate(&request)
- .and_then(|p| p.to_prompt_string().map(|p| p.0))
+ let local_prompt = build_prompt(&request)
+ .map(|(prompt, _)| prompt)
.map_err(|err| err.to_string());
debug_tx
@@ -827,6 +987,103 @@ impl Zeta {
}
}
+ pub const CONTEXT_RETRIEVAL_IDLE_DURATION: Duration = Duration::from_secs(10);
+ pub const CONTEXT_RETRIEVAL_DEBOUNCE_DURATION: Duration = Duration::from_secs(3);
+
+    // Refresh the related excerpts when the user has just begun editing after
+ // an idle period, and after they pause editing.
+ fn refresh_context_if_needed(
+ &mut self,
+ project: &Entity<Project>,
+ buffer: &Entity<language::Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) {
+ if !matches!(&self.options().context, ContextMode::Llm { .. }) {
+ return;
+ }
+
+ let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+ let now = Instant::now();
+ let was_idle = zeta_project
+ .refresh_context_timestamp
+ .map_or(true, |timestamp| {
+ now - timestamp > Self::CONTEXT_RETRIEVAL_IDLE_DURATION
+ });
+ zeta_project.refresh_context_timestamp = Some(now);
+ zeta_project.refresh_context_debounce_task = Some(cx.spawn({
+ let buffer = buffer.clone();
+ let project = project.clone();
+ async move |this, cx| {
+ if was_idle {
+ log::debug!("refetching edit prediction context after idle");
+ } else {
+ cx.background_executor()
+ .timer(Self::CONTEXT_RETRIEVAL_DEBOUNCE_DURATION)
+ .await;
+ log::debug!("refetching edit prediction context after pause");
+ }
+ this.update(cx, |this, cx| {
+ this.refresh_context(project, buffer, cursor_position, cx);
+ })
+ .ok()
+ }
+ }));
+ }
+
+ // Refresh the related excerpts asynchronously. Ensure the task runs to completion,
+ // and avoid spawning more than one concurrent task.
+ fn refresh_context(
+ &mut self,
+ project: Entity<Project>,
+ buffer: Entity<language::Buffer>,
+ cursor_position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(zeta_project) = self.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+
+ zeta_project
+ .refresh_context_task
+ .get_or_insert(cx.spawn(async move |this, cx| {
+ let related_excerpts = this
+ .update(cx, |this, cx| {
+ let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
+ return Task::ready(anyhow::Ok(HashMap::default()));
+ };
+
+ let ContextMode::Llm(options) = &this.options().context else {
+ return Task::ready(anyhow::Ok(HashMap::default()));
+ };
+
+ find_related_excerpts(
+ buffer.clone(),
+ cursor_position,
+ &project,
+ zeta_project.events.iter(),
+ options,
+ cx,
+ )
+ })
+ .ok()?
+ .await
+ .log_err()
+ .unwrap_or_default();
+ this.update(cx, |this, _cx| {
+ let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
+ return;
+ };
+ zeta_project.context = Some(related_excerpts);
+ zeta_project.refresh_context_task.take();
+ })
+ .ok()
+ }));
+ }
+
fn gather_nearby_diagnostics(
cursor_offset: usize,
diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
@@ -918,12 +1175,20 @@ impl Zeta {
cursor_point,
&snapshot,
parent_abs_path.as_deref(),
- &options.context,
+ match &options.context {
+ ContextMode::Llm(_) => {
+ // TODO
+ panic!("Llm mode not supported in zeta cli yet");
+ }
+ ContextMode::Syntax(edit_prediction_context_options) => {
+ edit_prediction_context_options
+ }
+ },
index_state.as_deref(),
)
.context("Failed to select excerpt")
.map(|context| {
- make_cloud_request(
+ make_syntax_context_cloud_request(
excerpt_path.into(),
context,
// TODO pass everything
@@ -963,7 +1228,7 @@ pub struct ZedUpdateRequiredError {
minimum_version: SemanticVersion,
}
-fn make_cloud_request(
+fn make_syntax_context_cloud_request(
excerpt_path: Arc<Path>,
context: EditPredictionContext,
events: Vec<predict_edits_v3::Event>,
@@ -1044,6 +1309,7 @@ fn make_cloud_request(
column: context.cursor_point.column,
},
referenced_declarations,
+ included_files: vec![],
signatures,
excerpt_parent,
events,
@@ -20,7 +20,10 @@ use ui::{ButtonLike, ContextMenu, ContextMenuEntry, DropdownMenu, KeyBinding, pr
use ui_input::InputField;
use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
use workspace::{Item, SplitDirection, Workspace};
-use zeta2::{PredictionDebugInfo, Zeta, Zeta2FeatureFlag, ZetaOptions};
+use zeta2::{
+ ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, LlmContextOptions, PredictionDebugInfo, Zeta,
+ Zeta2FeatureFlag, ZetaOptions,
+};
use edit_prediction_context::{EditPredictionContextOptions, EditPredictionExcerptOptions};
@@ -69,7 +72,7 @@ pub struct Zeta2Inspector {
min_excerpt_bytes_input: Entity<InputField>,
cursor_context_ratio_input: Entity<InputField>,
max_prompt_bytes_input: Entity<InputField>,
- max_retrieved_declarations: Entity<InputField>,
+ context_mode: ContextModeState,
active_view: ActiveView,
zeta: Entity<Zeta>,
_active_editor_subscription: Option<Subscription>,
@@ -77,6 +80,13 @@ pub struct Zeta2Inspector {
_receive_task: Task<()>,
}
+pub enum ContextModeState {
+ Llm,
+ Syntax {
+ max_retrieved_declarations: Entity<InputField>,
+ },
+}
+
#[derive(PartialEq)]
enum ActiveView {
Context,
@@ -143,36 +153,34 @@ impl Zeta2Inspector {
min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx),
- max_retrieved_declarations: Self::number_input("Max Retrieved Definitions", window, cx),
+ context_mode: ContextModeState::Llm,
zeta: zeta.clone(),
_active_editor_subscription: None,
_update_state_task: Task::ready(()),
_receive_task: receive_task,
};
- this.set_input_options(&zeta.read(cx).options().clone(), window, cx);
+ this.set_options_state(&zeta.read(cx).options().clone(), window, cx);
this
}
- fn set_input_options(
+ fn set_options_state(
&mut self,
options: &ZetaOptions,
window: &mut Window,
cx: &mut Context<Self>,
) {
+ let excerpt_options = options.context.excerpt();
self.max_excerpt_bytes_input.update(cx, |input, cx| {
- input.set_text(options.context.excerpt.max_bytes.to_string(), window, cx);
+ input.set_text(excerpt_options.max_bytes.to_string(), window, cx);
});
self.min_excerpt_bytes_input.update(cx, |input, cx| {
- input.set_text(options.context.excerpt.min_bytes.to_string(), window, cx);
+ input.set_text(excerpt_options.min_bytes.to_string(), window, cx);
});
self.cursor_context_ratio_input.update(cx, |input, cx| {
input.set_text(
format!(
"{:.2}",
- options
- .context
- .excerpt
- .target_before_cursor_over_total_bytes
+ excerpt_options.target_before_cursor_over_total_bytes
),
window,
cx,
@@ -181,20 +189,28 @@ impl Zeta2Inspector {
self.max_prompt_bytes_input.update(cx, |input, cx| {
input.set_text(options.max_prompt_bytes.to_string(), window, cx);
});
- self.max_retrieved_declarations.update(cx, |input, cx| {
- input.set_text(
- options.context.max_retrieved_declarations.to_string(),
- window,
- cx,
- );
- });
+
+ match &options.context {
+ ContextMode::Llm(_) => {
+ self.context_mode = ContextModeState::Llm;
+ }
+ ContextMode::Syntax(_) => {
+ self.context_mode = ContextModeState::Syntax {
+ max_retrieved_declarations: Self::number_input(
+ "Max Retrieved Definitions",
+ window,
+ cx,
+ ),
+ };
+ }
+ }
cx.notify();
}
- fn set_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
+ fn set_zeta_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
self.zeta.update(cx, |this, _cx| this.set_options(options));
- const THROTTLE_TIME: Duration = Duration::from_millis(100);
+ const DEBOUNCE_TIME: Duration = Duration::from_millis(100);
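+ // Spawning a new task below replaces (and cancels) the previous one, so rapid changes debounce into a single refresh.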
if let Some(prediction) = self.last_prediction.as_mut() {
if let Some(buffer) = prediction.buffer.upgrade() {
@@ -202,7 +218,7 @@ impl Zeta2Inspector {
let zeta = self.zeta.clone();
let project = self.project.clone();
prediction._task = Some(cx.spawn(async move |_this, cx| {
- cx.background_executor().timer(THROTTLE_TIME).await;
+ cx.background_executor().timer(DEBOUNCE_TIME).await;
if let Some(task) = zeta
.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer, position, cx)
@@ -255,25 +271,40 @@ impl Zeta2Inspector {
let zeta_options = this.zeta.read(cx).options().clone();
- let context_options = EditPredictionContextOptions {
- excerpt: EditPredictionExcerptOptions {
- max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx),
- min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx),
- target_before_cursor_over_total_bytes: number_input_value(
- &this.cursor_context_ratio_input,
- cx,
- ),
- },
- max_retrieved_declarations: number_input_value(
- &this.max_retrieved_declarations,
+ let excerpt_options = EditPredictionExcerptOptions {
+ max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx),
+ min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx),
+ target_before_cursor_over_total_bytes: number_input_value(
+ &this.cursor_context_ratio_input,
cx,
),
- ..zeta_options.context
};
- this.set_options(
+ let context = match zeta_options.context {
+ ContextMode::Llm(_context_options) => ContextMode::Llm(LlmContextOptions {
+ excerpt: excerpt_options,
+ }),
+ ContextMode::Syntax(context_options) => {
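+ // The declarations input only exists while the inspector is in syntax mode; otherwise fall back to the default.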
+ let max_retrieved_declarations = match &this.context_mode {
+ ContextModeState::Llm => {
+ zeta2::DEFAULT_SYNTAX_CONTEXT_OPTIONS.max_retrieved_declarations
+ }
+ ContextModeState::Syntax {
+ max_retrieved_declarations,
+ } => number_input_value(max_retrieved_declarations, cx),
+ };
+
+ ContextMode::Syntax(EditPredictionContextOptions {
+ excerpt: excerpt_options,
+ max_retrieved_declarations,
+ ..context_options
+ })
+ }
+ };
+
+ this.set_zeta_options(
ZetaOptions {
- context: context_options,
+ context,
max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx),
max_diagnostic_bytes: zeta_options.max_diagnostic_bytes,
prompt_format: zeta_options.prompt_format,
@@ -709,7 +740,7 @@ impl Zeta2Inspector {
.style(ButtonStyle::Outlined)
.size(ButtonSize::Large)
.on_click(cx.listener(|this, _, window, cx| {
- this.set_input_options(&zeta2::DEFAULT_OPTIONS, window, cx);
+ this.set_options_state(&zeta2::DEFAULT_OPTIONS, window, cx);
})),
),
)
@@ -722,19 +753,113 @@ impl Zeta2Inspector {
.items_end()
.child(self.max_excerpt_bytes_input.clone())
.child(self.min_excerpt_bytes_input.clone())
- .child(self.cursor_context_ratio_input.clone()),
+ .child(self.cursor_context_ratio_input.clone())
+ .child(self.render_context_mode_dropdown(window, cx)),
)
.child(
h_flex()
.gap_2()
.items_end()
- .child(self.max_retrieved_declarations.clone())
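+ // Only syntax mode exposes the "Max Retrieved Definitions" input.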
+ .children(match &self.context_mode {
+ ContextModeState::Llm => None,
+ ContextModeState::Syntax {
+ max_retrieved_declarations,
+ } => Some(max_retrieved_declarations.clone()),
+ })
.child(self.max_prompt_bytes_input.clone())
.child(self.render_prompt_format_dropdown(window, cx)),
),
)
}
+ fn render_context_mode_dropdown(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
+ let this = cx.weak_entity();
+
+ v_flex()
+ .gap_1p5()
+ .child(
+ Label::new("Context Mode")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .child(
+ DropdownMenu::new(
+ "ep-ctx-mode",
+ match &self.context_mode {
+ ContextModeState::Llm => "LLM-based",
+ ContextModeState::Syntax { .. } => "Syntax",
+ },
+ ContextMenu::build(window, cx, move |menu, _window, _cx| {
+ menu.item(
+ ContextMenuEntry::new("LLM-based")
+ .toggleable(
+ IconPosition::End,
+ matches!(self.context_mode, ContextModeState::Llm),
+ )
+ .handler({
+ let this = this.clone();
+ move |window, cx| {
+ this.update(cx, |this, cx| {
+ let current_options =
+ this.zeta.read(cx).options().clone();
+ match current_options.context.clone() {
+ ContextMode::Llm(_) => {}
+ ContextMode::Syntax(context_options) => {
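+ // Switching to LLM mode keeps the shared excerpt options and drops the syntax-specific ones.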
+ let options = ZetaOptions {
+ context: ContextMode::Llm(
+ LlmContextOptions {
+ excerpt: context_options.excerpt,
+ },
+ ),
+ ..current_options
+ };
+ this.set_options_state(&options, window, cx);
+ this.set_zeta_options(options, cx);
+ }
+ }
+ })
+ .ok();
+ }
+ }),
+ )
+ .item(
+ ContextMenuEntry::new("Syntax")
+ .toggleable(
+ IconPosition::End,
+ matches!(self.context_mode, ContextModeState::Syntax { .. }),
+ )
+ .handler({
+ move |window, cx| {
+ this.update(cx, |this, cx| {
+ let current_options =
+ this.zeta.read(cx).options().clone();
+ match current_options.context.clone() {
+ ContextMode::Llm(context_options) => {
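+ // Switching to syntax mode keeps the excerpt options and takes defaults for the rest.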
+ let options = ZetaOptions {
+ context: ContextMode::Syntax(
+ EditPredictionContextOptions {
+ excerpt: context_options.excerpt,
+ ..DEFAULT_SYNTAX_CONTEXT_OPTIONS
+ },
+ ),
+ ..current_options
+ };
+ this.set_options_state(&options, window, cx);
+ this.set_zeta_options(options, cx);
+ }
+ ContextMode::Syntax(_) => {}
+ }
+ })
+ .ok();
+ }
+ }),
+ )
+ }),
+ )
+ .style(ui::DropdownStyle::Outlined),
+ )
+ }
+
fn render_prompt_format_dropdown(&self, window: &mut Window, cx: &mut Context<Self>) -> Div {
let active_format = self.zeta.read(cx).options().prompt_format;
let this = cx.weak_entity();
@@ -765,7 +890,7 @@ impl Zeta2Inspector {
prompt_format,
..current_options
};
- this.set_options(options, cx);
+ this.set_zeta_options(options, cx);
})
.ok();
}
@@ -20,6 +20,7 @@ use reqwest_client::ReqwestClient;
use serde_json::json;
use std::{collections::HashSet, path::PathBuf, process::exit, str::FromStr, sync::Arc};
use zeta::{PerformPredictEditsParams, Zeta};
+use zeta2::ContextMode;
use crate::headless::ZetaCliAppState;
use crate::source_location::SourceLocation;
@@ -263,8 +264,8 @@ async fn get_context(
})?
.await?;
- let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?;
- let (prompt_string, section_labels) = planned_prompt.to_prompt_string()?;
+ let (prompt_string, section_labels) =
+ cloud_zeta2_prompt::build_prompt(&request)?;
match zeta2_args.output_format {
OutputFormat::Prompt => anyhow::Ok(prompt_string),
@@ -301,7 +302,7 @@ async fn get_context(
impl Zeta2Args {
fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
zeta2::ZetaOptions {
- context: EditPredictionContextOptions {
+ context: ContextMode::Syntax(EditPredictionContextOptions {
max_retrieved_declarations: self.max_retrieved_definitions,
use_imports: !self.disable_imports_gathering,
excerpt: EditPredictionExcerptOptions {
@@ -313,7 +314,7 @@ impl Zeta2Args {
score: EditPredictionScoreOptions {
omit_excerpt_overlaps,
},
- },
+ }),
max_diagnostic_bytes: self.max_diagnostic_bytes,
max_prompt_bytes: self.max_prompt_bytes,
prompt_format: self.prompt_format.clone().into(),
@@ -3,8 +3,8 @@ use ::util::{RangeExt, ResultExt as _};
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
use edit_prediction_context::{
- Declaration, DeclarationStyle, EditPredictionContext, Identifier, Imports, Reference,
- ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
+ Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, Identifier,
+ Imports, Reference, ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
};
use futures::StreamExt as _;
use futures::channel::mpsc;
@@ -32,6 +32,7 @@ use std::{
time::Duration,
};
use util::paths::PathStyle;
+use zeta2::ContextMode;
use crate::headless::ZetaCliAppState;
use crate::source_location::SourceLocation;
@@ -46,6 +47,10 @@ pub async fn retrieval_stats(
options: zeta2::ZetaOptions,
cx: &mut AsyncApp,
) -> Result<String> {
+ let ContextMode::Syntax(context_options) = options.context.clone() else {
+ anyhow::bail!("retrieval stats are only supported in ContextMode::Syntax");
+ };
+
let options = Arc::new(options);
let worktree_path = worktree.canonicalize()?;
@@ -264,10 +269,10 @@ pub async fn retrieval_stats(
.map(|project_file| {
let index_state = index_state.clone();
let lsp_definitions = lsp_definitions.clone();
- let options = options.clone();
let output_tx = output_tx.clone();
let done_count = done_count.clone();
let file_snapshots = file_snapshots.clone();
+ let context_options = context_options.clone();
cx.background_spawn(async move {
let snapshot = project_file.snapshot;
@@ -279,7 +284,7 @@ pub async fn retrieval_stats(
&snapshot,
);
- let imports = if options.context.use_imports {
+ let imports = if context_options.use_imports {
Imports::gather(&snapshot, Some(&project_file.parent_abs_path))
} else {
Imports::default()
@@ -311,7 +316,7 @@ pub async fn retrieval_stats(
&snapshot,
&index_state,
&file_snapshots,
- &options,
+ &context_options,
)
.await?;
@@ -958,7 +963,7 @@ async fn retrieve_definitions(
snapshot: &BufferSnapshot,
index: &Arc<SyntaxIndexState>,
file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
- options: &Arc<zeta2::ZetaOptions>,
+ context_options: &EditPredictionContextOptions,
) -> Result<RetrieveResult> {
let mut single_reference_map = HashMap::default();
single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
@@ -966,7 +971,7 @@ async fn retrieve_definitions(
query_point,
snapshot,
imports,
- &options.context,
+ context_options,
Some(&index),
|_, _, _| single_reference_map,
);