Detailed changes
@@ -1,7 +1,7 @@
use chrono::Duration;
use serde::{Deserialize, Serialize};
use std::{
- ops::Range,
+ ops::{Add, Range, Sub},
path::{Path, PathBuf},
sync::Arc,
};
@@ -18,8 +18,8 @@ pub struct PredictEditsRequest {
pub excerpt_path: Arc<Path>,
/// Within file
pub excerpt_range: Range<usize>,
- /// Within `excerpt`
- pub cursor_offset: usize,
+ pub excerpt_line_range: Range<Line>,
+ pub cursor_point: Point,
/// Within `signatures`
pub excerpt_parent: Option<usize>,
pub signatures: Vec<Signature>,
@@ -47,12 +47,13 @@ pub struct PredictEditsRequest {
pub enum PromptFormat {
MarkedExcerpt,
LabeledSections,
+ NumberedLines, // snippet lines are emitted as `<1-based line number>|<text>`
/// Prompt format intended for use via zeta_cli
OnlySnippets,
}
impl PromptFormat {
- pub const DEFAULT: PromptFormat = PromptFormat::LabeledSections;
+ pub const DEFAULT: PromptFormat = PromptFormat::NumberedLines; // NOTE(review): presumably what the `Default` impl returns — confirm
}
impl Default for PromptFormat {
@@ -73,6 +74,7 @@ impl std::fmt::Display for PromptFormat {
PromptFormat::MarkedExcerpt => write!(f, "Marked Excerpt"),
PromptFormat::LabeledSections => write!(f, "Labeled Sections"),
PromptFormat::OnlySnippets => write!(f, "Only Snippets"),
+ PromptFormat::NumberedLines => write!(f, "Numbered Lines"),
}
}
}
@@ -97,7 +99,7 @@ pub struct Signature {
pub parent_index: Option<usize>,
/// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
/// file is implicitly the file that contains the descendant declaration or excerpt.
- pub range: Range<usize>,
+ pub range: Range<Line>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -106,7 +108,7 @@ pub struct ReferencedDeclaration {
pub text: String,
pub text_is_truncated: bool,
/// Range of `text` within file, possibly truncated according to `text_is_truncated`
- pub range: Range<usize>,
+ pub range: Range<Line>,
/// Range within `text`
pub signature_range: Range<usize>,
/// Index within `signatures`.
@@ -169,10 +171,36 @@ pub struct DebugInfo {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edit {
pub path: Arc<Path>,
- pub range: Range<usize>,
+ pub range: Range<Line>, // lines holding the old text; `edits_from_response` turns both bounds into points at column 0
pub content: String,
}
/// Returns `true` when `value` compares equal to a freshly built `T::default()`.
fn is_default<T: Default + PartialEq>(value: &T) -> bool {
    let baseline = T::default();
    *value == baseline
}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
+pub struct Point { // position in a file; the derived ordering compares `line` first, then `column`
+ pub line: Line, // zero-based line index (the numbered-lines prompt prints it as `line.0 + 1`)
+ pub column: u32, // used as a byte index when the prompt builder slices the line's text
+}
+
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, PartialOrd, Eq, Ord)]
+#[serde(transparent)] // (de)serialized as the bare `u32` rather than as a wrapper
+pub struct Line(pub u32); // zero-based line index; shown 1-based in numbered-lines prompts
+
+impl Add for Line {
+ type Output = Self;
+
+ fn add(self, rhs: Self) -> Self::Output { // sums two line values, e.g. a start line plus a line count
+ Self(self.0 + rhs.0) // plain `u32` addition: panics on overflow when overflow checks are enabled
+ }
+}
+
+impl Sub for Line {
+ type Output = Self;
+
+ fn sub(self, rhs: Self) -> Self::Output { // difference between two line values
+ Self(self.0 - rhs.0) // plain `u32` subtraction: `rhs` must not exceed `self`, otherwise this underflows (panics when overflow checks are on, wraps otherwise)
+ }
+}
@@ -1,7 +1,9 @@
//! Zeta2 prompt planning and generation code shared with cloud.
use anyhow::{Context as _, Result, anyhow};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
+use cloud_llm_client::predict_edits_v3::{
+ self, Event, Line, Point, PromptFormat, ReferencedDeclaration,
+};
use indoc::indoc;
use ordered_float::OrderedFloat;
use rustc_hash::{FxHashMap, FxHashSet};
@@ -43,6 +45,42 @@ const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
}
"#};
+const NUMBERED_LINES_SYSTEM_PROMPT: &str = indoc! {r#"
+ # Instructions
+
+ You are a code completion assistant helping a programmer finish their work. Your task is to:
+
+ 1. Analyze the edit history to understand what the programmer is trying to achieve
+ 2. Identify any incomplete refactoring or changes that need to be finished
+ 3. Make the remaining edits that a human programmer would logically make next
+ 4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.
+
+ Focus on:
+ - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
+ - Completing any partially-applied changes across the codebase
+ - Ensuring consistency with the programming style and patterns already established
+ - Making edits that maintain or improve code quality
+ - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
+ - Don't write a lot of code if you're not sure what to do
+
+ Rules:
+ - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
+ - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+ - Write the edits in the unified diff format as shown in the example.
+
+ # Example output:
+
+ ```
+ --- a/distill-claude/tmp-outs/edits_history.txt
+ +++ b/distill-claude/tmp-outs/edits_history.txt
+ @@ -1,3 +1,3 @@
+ -
+ -
+ -import sys
+ +import json
+ ```
+"#}; // system prompt that `system_prompt` returns for `PromptFormat::NumberedLines`; asks the model to answer with a unified diff
+
pub struct PlannedPrompt<'a> {
request: &'a predict_edits_v3::PredictEditsRequest,
/// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
@@ -55,6 +93,7 @@ pub fn system_prompt(format: PromptFormat) -> &'static str {
match format {
PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
+ PromptFormat::NumberedLines => NUMBERED_LINES_SYSTEM_PROMPT,
// only intended for use via zeta_cli
PromptFormat::OnlySnippets => "",
}
@@ -63,7 +102,7 @@ pub fn system_prompt(format: PromptFormat) -> &'static str {
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
path: Arc<Path>,
- range: Range<usize>,
+ range: Range<Line>,
text: &'a str,
// TODO: Indicate this in the output
#[allow(dead_code)]
@@ -79,7 +118,7 @@ pub enum DeclarationStyle {
#[derive(Clone, Debug, Serialize)]
pub struct SectionLabels {
pub excerpt_index: usize,
- pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
+ pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
}
impl<'a> PlannedPrompt<'a> {
@@ -196,10 +235,24 @@ impl<'a> PlannedPrompt<'a> {
declaration.text.len()
));
};
+ let signature_start_line = declaration.range.start
+ + Line(
+ declaration.text[..declaration.signature_range.start]
+ .lines()
+ .count() as u32,
+ );
+ let signature_end_line = signature_start_line
+ + Line(
+ declaration.text
+ [declaration.signature_range.start..declaration.signature_range.end]
+ .lines()
+ .count() as u32,
+ );
+ let range = signature_start_line..signature_end_line;
+
PlannedSnippet {
path: declaration.path.clone(),
- range: (declaration.signature_range.start + declaration.range.start)
- ..(declaration.signature_range.end + declaration.range.start),
+ range,
text,
text_is_truncated: declaration.text_is_truncated,
}
@@ -318,7 +371,7 @@ impl<'a> PlannedPrompt<'a> {
}
let excerpt_snippet = PlannedSnippet {
path: self.request.excerpt_path.clone(),
- range: self.request.excerpt_range.clone(),
+ range: self.request.excerpt_line_range.clone(),
text: &self.request.excerpt,
text_is_truncated: false,
};
@@ -328,32 +381,33 @@ impl<'a> PlannedPrompt<'a> {
let mut excerpt_file_insertions = match self.request.prompt_format {
PromptFormat::MarkedExcerpt => vec![
(
- self.request.excerpt_range.start,
+ Point {
+ line: self.request.excerpt_line_range.start,
+ column: 0,
+ },
EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
),
+ (self.request.cursor_point, CURSOR_MARKER),
(
- self.request.excerpt_range.start + self.request.cursor_offset,
- CURSOR_MARKER,
- ),
- (
- self.request
- .excerpt_range
- .end
- .saturating_sub(0)
- .max(self.request.excerpt_range.start),
+ Point {
+ line: self.request.excerpt_line_range.end,
+ column: 0,
+ },
EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
),
],
- PromptFormat::LabeledSections => vec![(
- self.request.excerpt_range.start + self.request.cursor_offset,
- CURSOR_MARKER,
- )],
+ PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
+ PromptFormat::NumberedLines => vec![(self.request.cursor_point, CURSOR_MARKER)],
PromptFormat::OnlySnippets => vec![],
};
let mut prompt = String::new();
prompt.push_str("## User Edits\n\n");
- Self::push_events(&mut prompt, &self.request.events);
+ if self.request.events.is_empty() {
+ prompt.push_str("No edits yet.\n");
+ } else {
+ Self::push_events(&mut prompt, &self.request.events);
+ }
prompt.push_str("\n## Code\n\n");
let section_labels =
@@ -391,13 +445,17 @@ impl<'a> PlannedPrompt<'a> {
if *predicted {
writeln!(
output,
- "User accepted prediction {:?}:\n```diff\n{}\n```\n",
+ "User accepted prediction {:?}:\n`````diff\n{}\n`````\n",
path, diff
)
.unwrap();
} else {
- writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
- .unwrap();
+ writeln!(
+ output,
+ "User edited {:?}:\n`````diff\n{}\n`````\n",
+ path, diff
+ )
+ .unwrap();
}
}
}
@@ -407,7 +465,7 @@ impl<'a> PlannedPrompt<'a> {
fn push_file_snippets(
&self,
output: &mut String,
- excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
+ excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
) -> Result<SectionLabels> {
let mut section_ranges = Vec::new();
@@ -417,15 +475,13 @@ impl<'a> PlannedPrompt<'a> {
snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
// TODO: What if the snippets get expanded too large to be editable?
- let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
- let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
+ let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
+ let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
for snippet in snippets {
if let Some((_, current_snippet_range)) = current_snippet.as_mut()
- && snippet.range.start < current_snippet_range.end
+ && snippet.range.start <= current_snippet_range.end
{
- if snippet.range.end > current_snippet_range.end {
- current_snippet_range.end = snippet.range.end;
- }
+ current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
continue;
}
if let Some(current_snippet) = current_snippet.take() {
@@ -437,21 +493,24 @@ impl<'a> PlannedPrompt<'a> {
disjoint_snippets.push(current_snippet);
}
- writeln!(output, "```{}", file_path.display()).ok();
+ // TODO: remove filename=?
+ writeln!(output, "`````filename={}", file_path.display()).ok();
let mut skipped_last_snippet = false;
for (snippet, range) in disjoint_snippets {
let section_index = section_ranges.len();
match self.request.prompt_format {
- PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
- if range.start > 0 && !skipped_last_snippet {
+ PromptFormat::MarkedExcerpt
+ | PromptFormat::OnlySnippets
+ | PromptFormat::NumberedLines => {
+ if range.start.0 > 0 && !skipped_last_snippet {
output.push_str("…\n");
}
}
PromptFormat::LabeledSections => {
if is_excerpt_file
- && range.start <= self.request.excerpt_range.start
- && range.end >= self.request.excerpt_range.end
+ && range.start <= self.request.excerpt_line_range.start
+ && range.end >= self.request.excerpt_line_range.end
{
writeln!(output, "<|current_section|>").ok();
} else {
@@ -460,46 +519,83 @@ impl<'a> PlannedPrompt<'a> {
}
}
+ let push_full_snippet = |output: &mut String| {
+ if self.request.prompt_format == PromptFormat::NumberedLines {
+ for (i, line) in snippet.text.lines().enumerate() {
+ writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
+ }
+ } else {
+ output.push_str(&snippet.text);
+ }
+ anyhow::Ok(())
+ };
+
if is_excerpt_file {
if self.request.prompt_format == PromptFormat::OnlySnippets {
- if range.start >= self.request.excerpt_range.start
- && range.end <= self.request.excerpt_range.end
+ if range.start >= self.request.excerpt_line_range.start
+ && range.end <= self.request.excerpt_line_range.end
{
skipped_last_snippet = true;
} else {
skipped_last_snippet = false;
output.push_str(snippet.text);
}
- } else {
- let mut last_offset = range.start;
- let mut i = 0;
- while i < excerpt_file_insertions.len() {
- let (offset, insertion) = &excerpt_file_insertions[i];
- let found = *offset >= range.start && *offset <= range.end;
+ } else if !excerpt_file_insertions.is_empty() {
+ let lines = snippet.text.lines().collect::<Vec<_>>();
+ let push_line = |output: &mut String, line_ix: usize| {
+ if self.request.prompt_format == PromptFormat::NumberedLines {
+ write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
+ }
+ anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
+ };
+ let mut last_line_ix = 0;
+ let mut insertion_ix = 0;
+ while insertion_ix < excerpt_file_insertions.len() {
+ let (point, insertion) = &excerpt_file_insertions[insertion_ix];
+ let found = point.line >= range.start && point.line <= range.end;
if found {
excerpt_index = Some(section_index);
- output.push_str(
- &snippet.text[last_offset - range.start..offset - range.start],
- );
- output.push_str(insertion);
- last_offset = *offset;
- excerpt_file_insertions.remove(i);
+ let insertion_line_ix = (point.line.0 - range.start.0) as usize;
+ for line_ix in last_line_ix..insertion_line_ix {
+ push_line(output, line_ix)?;
+ }
+ if let Some(next_line) = lines.get(insertion_line_ix) {
+ if self.request.prompt_format == PromptFormat::NumberedLines {
+ write!(
+ output,
+ "{}|",
+ insertion_line_ix as u32 + range.start.0 + 1
+ )?
+ }
+ output.push_str(&next_line[..point.column as usize]);
+ output.push_str(insertion);
+ writeln!(output, "{}", &next_line[point.column as usize..])?;
+ } else {
+ writeln!(output, "{}", insertion)?;
+ }
+ last_line_ix = insertion_line_ix + 1;
+ excerpt_file_insertions.remove(insertion_ix);
continue;
}
- i += 1;
+ insertion_ix += 1;
}
skipped_last_snippet = false;
- output.push_str(&snippet.text[last_offset - range.start..]);
+ for line_ix in last_line_ix..lines.len() {
+ push_line(output, line_ix)?;
+ }
+ } else {
+ skipped_last_snippet = false;
+ push_full_snippet(output)?;
}
} else {
skipped_last_snippet = false;
- output.push_str(snippet.text);
+ push_full_snippet(output)?;
}
section_ranges.push((snippet.path.clone(), range));
}
- output.push_str("```\n\n");
+ output.push_str("`````\n\n");
}
Ok(SectionLabels {
@@ -1,3 +1,4 @@
+use cloud_llm_client::predict_edits_v3::{self, Line};
use language::{Language, LanguageId};
use project::ProjectEntryId;
use std::ops::Range;
@@ -91,6 +92,18 @@ impl Declaration {
}
}
+ pub fn item_line_range(&self) -> Range<Line> { // line-based counterpart of the item's byte range
+ match self {
+ Declaration::File { declaration, .. } => declaration.item_line_range.clone(), // file declarations carry the line range as a stored field
+ Declaration::Buffer {
+ declaration, rope, ..
+ } => {
+ Line(rope.offset_to_point(declaration.item_range.start).row) // buffer declarations derive it on demand from the byte offsets
+ ..Line(rope.offset_to_point(declaration.item_range.end).row) // NOTE(review): only an exclusive end if `item_range.end` sits at a line start — confirm for truncated ranges
+ }
+ }
+ }
+
pub fn item_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
@@ -130,6 +143,18 @@ impl Declaration {
}
}
+ pub fn signature_line_range(&self) -> Range<Line> { // line-based counterpart of the signature's byte range
+ match self {
+ Declaration::File { declaration, .. } => declaration.signature_line_range.clone(), // file declarations carry the line range as a stored field
+ Declaration::Buffer {
+ declaration, rope, ..
+ } => {
+ Line(rope.offset_to_point(declaration.signature_range.start).row) // buffer declarations derive it on demand from the byte offsets
+ ..Line(rope.offset_to_point(declaration.signature_range.end).row) // NOTE(review): only an exclusive end if `signature_range.end` sits at a line start — confirm for truncated ranges
+ }
+ }
+ }
+
pub fn signature_range_in_item_text(&self) -> Range<usize> {
let signature_range = self.signature_range();
let item_range = self.item_range();
@@ -142,7 +167,7 @@ fn expand_range_to_line_boundaries_and_truncate(
range: &Range<usize>,
limit: usize,
rope: &Rope,
-) -> (Range<usize>, bool) {
+) -> (Range<usize>, Range<predict_edits_v3::Line>, bool) {
let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
point_range.start.column = 0;
point_range.end.row += 1;
@@ -155,7 +180,10 @@ fn expand_range_to_line_boundaries_and_truncate(
item_range.end = item_range.start + limit;
}
item_range.end = rope.clip_offset(item_range.end, Bias::Left);
- (item_range, is_truncated)
+
+ let line_range =
+ predict_edits_v3::Line(point_range.start.row)..predict_edits_v3::Line(point_range.end.row);
+ (item_range, line_range, is_truncated)
}
#[derive(Debug, Clone)]
@@ -164,25 +192,30 @@ pub struct FileDeclaration {
pub identifier: Identifier,
/// offset range of the declaration in the file, expanded to line boundaries and truncated
pub item_range: Range<usize>,
+ /// line range of the declaration in the file, potentially truncated
+ pub item_line_range: Range<predict_edits_v3::Line>,
/// text of `item_range`
pub text: Arc<str>,
/// whether `text` was truncated
pub text_is_truncated: bool,
/// offset range of the signature in the file, expanded to line boundaries and truncated
pub signature_range: Range<usize>,
+ /// line range of the signature in the file, truncated
+ pub signature_line_range: Range<Line>,
/// whether `signature` was truncated
pub signature_is_truncated: bool,
}
impl FileDeclaration {
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
- let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
- &declaration.item_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
+ let (item_range_in_file, item_line_range_in_file, text_is_truncated) =
+ expand_range_to_line_boundaries_and_truncate(
+ &declaration.item_range,
+ ITEM_TEXT_TRUNCATION_LENGTH,
+ rope,
+ );
- let (mut signature_range_in_file, mut signature_is_truncated) =
+ let (mut signature_range_in_file, signature_line_range, mut signature_is_truncated) =
expand_range_to_line_boundaries_and_truncate(
&declaration.signature_range,
ITEM_TEXT_TRUNCATION_LENGTH,
@@ -202,6 +235,7 @@ impl FileDeclaration {
parent: None,
identifier: declaration.identifier,
signature_range: signature_range_in_file,
+ signature_line_range,
signature_is_truncated,
text: rope
.chunks_in_range(item_range_in_file.clone())
@@ -209,6 +243,7 @@ impl FileDeclaration {
.into(),
text_is_truncated,
item_range: item_range_in_file,
+ item_line_range: item_line_range_in_file,
}
}
}
@@ -225,12 +260,13 @@ pub struct BufferDeclaration {
impl BufferDeclaration {
pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
- let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
- &declaration.item_range,
- ITEM_TEXT_TRUNCATION_LENGTH,
- rope,
- );
- let (signature_range, signature_range_is_truncated) =
+ let (item_range, _item_line_range, item_range_is_truncated) =
+ expand_range_to_line_boundaries_and_truncate(
+ &declaration.item_range,
+ ITEM_TEXT_TRUNCATION_LENGTH,
+ rope,
+ );
+ let (signature_range, _signature_line_range, signature_range_is_truncated) =
expand_range_to_line_boundaries_and_truncate(
&declaration.signature_range,
ITEM_TEXT_TRUNCATION_LENGTH,
@@ -9,6 +9,7 @@ pub mod text_similarity;
use std::{path::Path, sync::Arc};
+use cloud_llm_client::predict_edits_v3;
use collections::HashMap;
use gpui::{App, AppContext as _, Entity, Task};
use language::BufferSnapshot;
@@ -21,6 +22,8 @@ pub use imports::*;
pub use reference::*;
pub use syntax_index::*;
+pub use predict_edits_v3::Line;
+
#[derive(Clone, Debug, PartialEq)]
pub struct EditPredictionContextOptions {
pub use_imports: bool,
@@ -32,7 +35,7 @@ pub struct EditPredictionContextOptions {
pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
- pub cursor_offset_in_excerpt: usize,
+ pub cursor_point: Point, // cursor as a row/column point; `make_cloud_request` copies `row` and `column` into the request's `cursor_point`
pub declarations: Vec<ScoredDeclaration>,
}
@@ -124,8 +127,6 @@ impl EditPredictionContext {
);
let cursor_offset_in_file = cursor_point.to_offset(buffer);
- // TODO fix this to not need saturating_sub
- let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
let declarations = if let Some(index_state) = index_state {
let references = get_references(&excerpt, &excerpt_text, buffer);
@@ -148,7 +149,7 @@ impl EditPredictionContext {
Some(Self {
excerpt,
excerpt_text,
- cursor_offset_in_excerpt,
+ cursor_point,
declarations,
})
}
@@ -4,7 +4,7 @@ use text::{Point, ToOffset as _, ToPoint as _};
use tree_sitter::{Node, TreeCursor};
use util::RangeExt;
-use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState};
+use crate::{BufferDeclaration, Line, declaration::DeclarationId, syntax_index::SyntaxIndexState};
// TODO:
//
@@ -35,6 +35,7 @@ pub struct EditPredictionExcerptOptions {
#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
+ pub line_range: Range<Line>, // line-based counterpart of `range`; supplied by the caller and stored without being cross-checked against it
pub parent_declarations: Vec<(DeclarationId, Range<usize>)>,
pub size: usize,
}
@@ -86,12 +87,19 @@ impl EditPredictionExcerpt {
buffer.len(),
options.max_bytes
);
- return Some(EditPredictionExcerpt::new(0..buffer.len(), Vec::new()));
+ let offset_range = 0..buffer.len();
+ let line_range = Line(0)..Line(buffer.max_point().row);
+ return Some(EditPredictionExcerpt::new(
+ offset_range,
+ line_range,
+ Vec::new(),
+ ));
}
let query_offset = query_point.to_offset(buffer);
- let query_range = Point::new(query_point.row, 0).to_offset(buffer)
- ..Point::new(query_point.row + 1, 0).to_offset(buffer);
+ let query_line_range = query_point.row..query_point.row + 1;
+ let query_range = Point::new(query_line_range.start, 0).to_offset(buffer)
+ ..Point::new(query_line_range.end, 0).to_offset(buffer);
if query_range.len() >= options.max_bytes {
return None;
}
@@ -107,6 +115,7 @@ impl EditPredictionExcerpt {
let excerpt_selector = ExcerptSelector {
query_offset,
query_range,
+ query_line_range: Line(query_line_range.start)..Line(query_line_range.end),
parent_declarations: &parent_declarations,
buffer,
options,
@@ -130,7 +139,11 @@ impl EditPredictionExcerpt {
excerpt_selector.select_lines()
}
- fn new(range: Range<usize>, parent_declarations: Vec<(DeclarationId, Range<usize>)>) -> Self {
+ fn new(
+ range: Range<usize>,
+ line_range: Range<Line>,
+ parent_declarations: Vec<(DeclarationId, Range<usize>)>,
+ ) -> Self {
let size = range.len()
+ parent_declarations
.iter()
@@ -140,10 +153,11 @@ impl EditPredictionExcerpt {
range,
parent_declarations,
size,
+ line_range,
}
}
- fn with_expanded_range(&self, new_range: Range<usize>) -> Self {
+ fn with_expanded_range(&self, new_range: Range<usize>, new_line_range: Range<Line>) -> Self {
if !new_range.contains_inclusive(&self.range) {
// this is an issue because parent_signature_ranges may be incorrect
log::error!("bug: with_expanded_range called with disjoint range");
@@ -155,7 +169,7 @@ impl EditPredictionExcerpt {
}
parent_declarations.push((*declaration_id, range.clone()));
}
- Self::new(new_range, parent_declarations)
+ Self::new(new_range, new_line_range, parent_declarations)
}
fn parent_signatures_size(&self) -> usize {
@@ -166,6 +180,7 @@ impl EditPredictionExcerpt {
struct ExcerptSelector<'a> {
query_offset: usize,
query_range: Range<usize>,
+ query_line_range: Range<Line>,
parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)],
buffer: &'a BufferSnapshot,
options: &'a EditPredictionExcerptOptions,
@@ -178,10 +193,13 @@ impl<'a> ExcerptSelector<'a> {
let mut cursor = selected_layer_root.walk();
loop {
- let excerpt_range = node_line_start(cursor.node()).to_offset(&self.buffer)
- ..node_line_end(cursor.node()).to_offset(&self.buffer);
+ let line_start = node_line_start(cursor.node());
+ let line_end = node_line_end(cursor.node());
+ let line_range = Line(line_start.row)..Line(line_end.row);
+ let excerpt_range =
+ line_start.to_offset(&self.buffer)..line_end.to_offset(&self.buffer);
if excerpt_range.contains_inclusive(&self.query_range) {
- let excerpt = self.make_excerpt(excerpt_range);
+ let excerpt = self.make_excerpt(excerpt_range, line_range);
if excerpt.size <= self.options.max_bytes {
return Some(self.expand_to_siblings(&mut cursor, excerpt));
}
@@ -272,9 +290,13 @@ impl<'a> ExcerptSelector<'a> {
let mut forward = None;
while !forward_done {
- let new_end = node_line_end(forward_cursor.node()).to_offset(&self.buffer);
+ let new_end_point = node_line_end(forward_cursor.node());
+ let new_end = new_end_point.to_offset(&self.buffer);
if new_end > excerpt.range.end {
- let new_excerpt = excerpt.with_expanded_range(excerpt.range.start..new_end);
+ let new_excerpt = excerpt.with_expanded_range(
+ excerpt.range.start..new_end,
+ excerpt.line_range.start..Line(new_end_point.row),
+ );
if new_excerpt.size <= self.options.max_bytes {
forward = Some(new_excerpt);
break;
@@ -289,9 +311,13 @@ impl<'a> ExcerptSelector<'a> {
let mut backward = None;
while !backward_done {
- let new_start = node_line_start(backward_cursor.node()).to_offset(&self.buffer);
+ let new_start_point = node_line_start(backward_cursor.node());
+ let new_start = new_start_point.to_offset(&self.buffer);
if new_start < excerpt.range.start {
- let new_excerpt = excerpt.with_expanded_range(new_start..excerpt.range.end);
+ let new_excerpt = excerpt.with_expanded_range(
+ new_start..excerpt.range.end,
+ Line(new_start_point.row)..excerpt.line_range.end,
+ );
if new_excerpt.size <= self.options.max_bytes {
backward = Some(new_excerpt);
break;
@@ -339,7 +365,7 @@ impl<'a> ExcerptSelector<'a> {
fn select_lines(&self) -> Option<EditPredictionExcerpt> {
// early return if line containing query_offset is already too large
- let excerpt = self.make_excerpt(self.query_range.clone());
+ let excerpt = self.make_excerpt(self.query_range.clone(), self.query_line_range.clone());
if excerpt.size > self.options.max_bytes {
log::debug!(
"excerpt for cursor line is {} bytes, which exceeds the window",
@@ -353,24 +379,24 @@ impl<'a> ExcerptSelector<'a> {
let before_bytes =
(self.options.target_before_cursor_over_total_bytes * bytes_remaining as f32) as usize;
- let start_point = {
+ let start_line = {
let offset = self.query_offset.saturating_sub(before_bytes);
let point = offset.to_point(self.buffer);
- Point::new(point.row + 1, 0)
+ Line(point.row + 1)
};
- let start_offset = start_point.to_offset(&self.buffer);
- let end_point = {
+ let start_offset = Point::new(start_line.0, 0).to_offset(&self.buffer);
+ let end_line = {
let offset = start_offset + bytes_remaining;
let point = offset.to_point(self.buffer);
- Point::new(point.row, 0)
+ Line(point.row)
};
- let end_offset = end_point.to_offset(&self.buffer);
+ let end_offset = Point::new(end_line.0, 0).to_offset(&self.buffer);
// this could be expanded further since recalculated `signature_size` may be smaller, but
// skipping that for now for simplicity
//
// TODO: could also consider checking if lines immediately before / after fit.
- let excerpt = self.make_excerpt(start_offset..end_offset);
+ let excerpt = self.make_excerpt(start_offset..end_offset, start_line..end_line);
if excerpt.size > self.options.max_bytes {
log::error!(
"bug: line-based excerpt selection has size {}, \
@@ -382,14 +408,14 @@ impl<'a> ExcerptSelector<'a> {
return Some(excerpt);
}
- fn make_excerpt(&self, range: Range<usize>) -> EditPredictionExcerpt {
+ fn make_excerpt(&self, range: Range<usize>, line_range: Range<Line>) -> EditPredictionExcerpt { // parents are picked from the byte `range` only; `line_range` is forwarded unchanged
let parent_declarations = self
.parent_declarations
.iter()
.filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range))
.map(|(id, declaration)| (*id, declaration.signature_range.clone()))
.collect();
- EditPredictionExcerpt::new(range, parent_declarations)
+ EditPredictionExcerpt::new(range, line_range, parent_declarations)
}
/// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt.
@@ -33,7 +33,7 @@ pub struct EditPrediction {
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
- _buffer: Entity<Buffer>,
+ pub buffer: Entity<Buffer>,
}
impl EditPrediction {
@@ -108,7 +108,7 @@ impl EditPrediction {
edits,
snapshot,
edit_preview,
- _buffer: buffer,
+ buffer,
})
}
@@ -184,6 +184,10 @@ pub fn interpolate_edits(
if edits.is_empty() { None } else { Some(edits) }
}
+pub fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> { // maps each bound to column 0 of its line
+ language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0) // the end point is column 0 of the end line, so that line's own text falls outside the range
+}
+
fn edits_from_response(
edits: &[predict_edits_v3::Edit],
snapshot: &TextBufferSnapshot,
@@ -191,12 +195,14 @@ fn edits_from_response(
edits
.iter()
.flat_map(|edit| {
- let old_text = snapshot.text_for_range(edit.range.clone());
+ let point_range = line_range_to_point_range(edit.range.clone());
+ let offset = point_range.to_offset(snapshot).start;
+ let old_text = snapshot.text_for_range(point_range);
excerpt_edits_from_response(
old_text.collect::<Cow<str>>(),
&edit.content,
- edit.range.start,
+ offset,
&snapshot,
)
})
@@ -252,6 +258,7 @@ mod tests {
use super::*;
use cloud_llm_client::predict_edits_v3;
+ use edit_prediction_context::Line;
use gpui::{App, Entity, TestAppContext, prelude::*};
use indoc::indoc;
use language::{Buffer, ToOffset as _};
@@ -278,7 +285,7 @@ mod tests {
// TODO cover more cases when multi-file is supported
let big_edits = vec![predict_edits_v3::Edit {
path: PathBuf::from("test.txt").into(),
- range: 0..old.len(),
+ range: Line(0)..Line(old.lines().count() as u32),
content: new.into(),
}];
@@ -317,7 +324,7 @@ mod tests {
edits,
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
path: Path::new("test.txt").into(),
- _buffer: buffer.clone(),
+ buffer: buffer.clone(),
edit_preview,
};
@@ -17,8 +17,8 @@ use gpui::{
App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
http_client, prelude::*,
};
+use language::BufferSnapshot;
use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
-use language::{BufferSnapshot, TextBufferSnapshot};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use project::Project;
use release_channel::AppVersion;
@@ -106,30 +106,40 @@ struct ZetaProject {
current_prediction: Option<CurrentEditPrediction>,
}
-#[derive(Clone)]
+#[derive(Debug, Clone)]
struct CurrentEditPrediction {
pub requested_by_buffer_id: EntityId, // compared with `prediction.buffer.entity_id()` in `should_replace_prediction`
pub prediction: EditPrediction,
}
impl CurrentEditPrediction {
- fn should_replace_prediction(
- &self,
- old_prediction: &Self,
- snapshot: &TextBufferSnapshot,
- ) -> bool {
- if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
+ fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
+ let Some(new_edits) = self
+ .prediction
+ .interpolate(&self.prediction.buffer.read(cx))
+ else {
+ return false;
+ };
+
+ if self.prediction.buffer != old_prediction.prediction.buffer {
return true;
}
- let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
+ let Some(old_edits) = old_prediction
+ .prediction
+ .interpolate(&old_prediction.prediction.buffer.read(cx))
+ else {
return true;
};
- let Some(new_edits) = self.prediction.interpolate(snapshot) else {
- return false;
- };
- if old_edits.len() == 1 && new_edits.len() == 1 {
+ // This reduces the occurrence of UI thrash from replacing edits
+ //
+ // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
+ if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
+ && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
+ && old_edits.len() == 1
+ && new_edits.len() == 1
+ {
let (old_range, old_text) = &old_edits[0];
let (new_range, new_text) = &new_edits[0];
new_range == old_range && new_text.starts_with(old_text)
@@ -421,8 +431,7 @@ impl Zeta {
.current_prediction
.as_ref()
.is_none_or(|old_prediction| {
- new_prediction
- .should_replace_prediction(&old_prediction, buffer.read(cx))
+ new_prediction.should_replace_prediction(&old_prediction, cx)
})
{
project_state.current_prediction = Some(new_prediction);
@@ -926,7 +935,7 @@ fn make_cloud_request(
referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
path: path.as_std_path().into(),
text: text.into(),
- range: snippet.declaration.item_range(),
+ range: snippet.declaration.item_line_range(),
text_is_truncated,
signature_range: snippet.declaration.signature_range_in_item_text(),
parent_index,
@@ -954,8 +963,12 @@ fn make_cloud_request(
predict_edits_v3::PredictEditsRequest {
excerpt_path,
excerpt: context.excerpt_text.body,
+ excerpt_line_range: context.excerpt.line_range,
excerpt_range: context.excerpt.range,
- cursor_offset: context.cursor_offset_in_excerpt,
+ cursor_point: predict_edits_v3::Point {
+ line: predict_edits_v3::Line(context.cursor_point.row),
+ column: context.cursor_point.column,
+ },
referenced_declarations,
signatures,
excerpt_parent,
@@ -992,7 +1005,7 @@ fn add_signature(
text: text.into(),
text_is_truncated,
parent_index,
- range: parent_declaration.signature_range(),
+ range: parent_declaration.signature_line_range(),
});
declaration_to_signature_index.insert(declaration_id, signature_index);
Some(signature_index)
@@ -1007,7 +1020,8 @@ mod tests {
use client::UserStore;
use clock::FakeSystemClock;
- use cloud_llm_client::predict_edits_v3;
+ use cloud_llm_client::predict_edits_v3::{self, Point};
+ use edit_prediction_context::Line;
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -1067,7 +1081,7 @@ mod tests {
request_id: Uuid::new_v4(),
edits: vec![predict_edits_v3::Edit {
path: Path::new(path!("root/1.txt")).into(),
- range: 0..snapshot1.len(),
+ range: Line(0)..Line(snapshot1.max_point().row + 1),
content: "Hello!\nHow are you?\nBye".into(),
}],
debug_info: None,
@@ -1083,7 +1097,6 @@ mod tests {
});
// Prediction for another file
-
let prediction_task = zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer1, position, cx)
});
@@ -1093,14 +1106,13 @@ mod tests {
request_id: Uuid::new_v4(),
edits: vec![predict_edits_v3::Edit {
path: Path::new(path!("root/2.txt")).into(),
- range: 0..snapshot1.len(),
+ range: Line(0)..Line(snapshot1.max_point().row + 1),
content: "Hola!\nComo estas?\nAdios".into(),
}],
debug_info: None,
})
.unwrap();
prediction_task.await.unwrap();
-
zeta.read_with(cx, |zeta, cx| {
let prediction = zeta
.current_prediction_for_buffer(&buffer1, &project, cx)
@@ -1159,14 +1171,20 @@ mod tests {
request.excerpt_path.as_ref(),
Path::new(path!("root/foo.md"))
);
- assert_eq!(request.cursor_offset, 10);
+ assert_eq!(
+ request.cursor_point,
+ Point {
+ line: Line(1),
+ column: 3
+ }
+ );
respond_tx
.send(predict_edits_v3::PredictEditsResponse {
request_id: Uuid::new_v4(),
edits: vec![predict_edits_v3::Edit {
path: Path::new(path!("root/foo.md")).into(),
- range: 0..snapshot.len(),
+ range: Line(0)..Line(snapshot.max_point().row + 1),
content: "Hello!\nHow are you?\nBye".into(),
}],
debug_info: None,
@@ -1244,7 +1262,7 @@ mod tests {
request_id: Uuid::new_v4(),
edits: vec![predict_edits_v3::Edit {
path: Path::new(path!("root/foo.md")).into(),
- range: 0..snapshot.len(),
+ range: Line(0)..Line(snapshot.max_point().row + 1),
content: "Hello!\nHow are you?\nBye".into(),
}],
debug_info: None,
@@ -98,10 +98,11 @@ struct Zeta2Args {
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum PromptFormat {
- #[default]
MarkedExcerpt,
LabeledSections,
OnlySnippets,
+ #[default] // makes `NumberedLines` the value produced by the derived `Default`
+ NumberedLines, // converted to `predict_edits_v3::PromptFormat::NumberedLines` by the `Into` impl
}
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
@@ -110,6 +111,7 @@ impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
Self::MarkedExcerpt => predict_edits_v3::PromptFormat::MarkedExcerpt,
Self::LabeledSections => predict_edits_v3::PromptFormat::LabeledSections,
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
+ Self::NumberedLines => predict_edits_v3::PromptFormat::NumberedLines,
}
}
}