Detailed changes
@@ -39,6 +39,7 @@ dependencies = [
"util",
"uuid",
"watch",
+ "zlog",
]
[[package]]
@@ -3198,7 +3199,9 @@ dependencies = [
"indoc",
"ordered-float 2.10.1",
"rustc-hash 2.1.1",
+ "schemars 1.0.4",
"serde",
+ "serde_json",
"strum 0.27.2",
]
@@ -21675,10 +21678,10 @@ dependencies = [
"language_model",
"log",
"lsp",
+ "open_ai",
"pretty_assertions",
"project",
"release_channel",
- "schemars 1.0.4",
"serde",
"serde_json",
"settings",
@@ -21687,6 +21690,7 @@ dependencies = [
"uuid",
"workspace",
"worktree",
+ "zlog",
]
[[package]]
@@ -56,3 +56,4 @@ rand.workspace = true
tempfile.workspace = true
util.workspace = true
settings.workspace = true
+zlog.workspace = true
@@ -1,5 +1,4 @@
pub mod predict_edits_v3;
-pub mod udiff;
use std::str::FromStr;
use std::sync::Arc;
@@ -1,7 +1,7 @@
use chrono::Duration;
use serde::{Deserialize, Serialize};
use std::{
- fmt::Display,
+ fmt::{Display, Write as _},
ops::{Add, Range, Sub},
path::{Path, PathBuf},
sync::Arc,
@@ -11,7 +11,14 @@ use uuid::Uuid;
use crate::PredictEditsGitInfo;
-// TODO: snippet ordering within file / relative to excerpt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PlanContextRetrievalRequest {
+ pub excerpt: String,
+ pub excerpt_path: Arc<Path>,
+ pub excerpt_line_range: Range<Line>,
+ pub cursor_file_max_row: Line,
+ pub events: Vec<Event>,
+}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictEditsRequest {
@@ -125,15 +132,15 @@ impl Display for Event {
write!(
f,
"// User accepted prediction:\n--- a/{}\n+++ b/{}\n{diff}",
- old_path.display(),
- new_path.display()
+ DiffPathFmt(old_path),
+ DiffPathFmt(new_path)
)
} else {
write!(
f,
"--- a/{}\n+++ b/{}\n{diff}",
- old_path.display(),
- new_path.display()
+ DiffPathFmt(old_path),
+ DiffPathFmt(new_path)
)
}
}
@@ -141,6 +148,24 @@ impl Display for Event {
}
}
+/// always format the Path as a unix path with `/` as the path sep in Diffs
+pub struct DiffPathFmt<'a>(pub &'a Path);
+
+impl<'a> std::fmt::Display for DiffPathFmt<'a> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let mut is_first = true;
+ for component in self.0.components() {
+ if !is_first {
+ f.write_char('/')?;
+ } else {
+ is_first = false;
+ }
+ write!(f, "{}", component.as_os_str().display())?;
+ }
+ Ok(())
+ }
+}
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signature {
pub text: String,
@@ -1,294 +0,0 @@
-use std::{borrow::Cow, fmt::Display};
-
-#[derive(Debug, PartialEq)]
-pub enum DiffLine<'a> {
- OldPath { path: Cow<'a, str> },
- NewPath { path: Cow<'a, str> },
- HunkHeader(Option<HunkLocation>),
- Context(&'a str),
- Deletion(&'a str),
- Addition(&'a str),
- Garbage(&'a str),
-}
-
-#[derive(Debug, PartialEq)]
-pub struct HunkLocation {
- start_line_old: u32,
- count_old: u32,
- start_line_new: u32,
- count_new: u32,
-}
-
-impl<'a> DiffLine<'a> {
- pub fn parse(line: &'a str) -> Self {
- Self::try_parse(line).unwrap_or(Self::Garbage(line))
- }
-
- fn try_parse(line: &'a str) -> Option<Self> {
- if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
- let path = parse_header_path("a/", header);
- Some(Self::OldPath { path })
- } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) {
- Some(Self::NewPath {
- path: parse_header_path("b/", header),
- })
- } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) {
- if header.starts_with("...") {
- return Some(Self::HunkHeader(None));
- }
-
- let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?;
- let mut parts = header.split_ascii_whitespace();
- let count_old = parts.next()?;
- let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?;
-
- Some(Self::HunkHeader(Some(HunkLocation {
- start_line_old: start_line_old.parse::<u32>().ok()?.saturating_sub(1),
- count_old: count_old.parse().ok()?,
- start_line_new: start_line_new.parse::<u32>().ok()?.saturating_sub(1),
- count_new: count_new.parse().ok()?,
- })))
- } else if let Some(deleted_header) = line.strip_prefix("-") {
- Some(Self::Deletion(deleted_header))
- } else if line.is_empty() {
- Some(Self::Context(""))
- } else if let Some(context) = line.strip_prefix(" ") {
- Some(Self::Context(context))
- } else {
- Some(Self::Addition(line.strip_prefix("+")?))
- }
- }
-}
-
-impl<'a> Display for DiffLine<'a> {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- DiffLine::OldPath { path } => write!(f, "--- {path}"),
- DiffLine::NewPath { path } => write!(f, "+++ {path}"),
- DiffLine::HunkHeader(Some(hunk_location)) => {
- write!(
- f,
- "@@ -{},{} +{},{} @@",
- hunk_location.start_line_old + 1,
- hunk_location.count_old,
- hunk_location.start_line_new + 1,
- hunk_location.count_new
- )
- }
- DiffLine::HunkHeader(None) => write!(f, "@@ ... @@"),
- DiffLine::Context(content) => write!(f, " {content}"),
- DiffLine::Deletion(content) => write!(f, "-{content}"),
- DiffLine::Addition(content) => write!(f, "+{content}"),
- DiffLine::Garbage(line) => write!(f, "{line}"),
- }
- }
-}
-
-fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> {
- if !header.contains(['"', '\\']) {
- let path = header.split_ascii_whitespace().next().unwrap_or(header);
- return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path));
- }
-
- let mut path = String::with_capacity(header.len());
- let mut in_quote = false;
- let mut chars = header.chars().peekable();
- let mut strip_prefix = Some(strip_prefix);
-
- while let Some(char) = chars.next() {
- if char == '"' {
- in_quote = !in_quote;
- } else if char == '\\' {
- let Some(&next_char) = chars.peek() else {
- break;
- };
- chars.next();
- path.push(next_char);
- } else if char.is_ascii_whitespace() && !in_quote {
- break;
- } else {
- path.push(char);
- }
-
- if let Some(prefix) = strip_prefix
- && path == prefix
- {
- strip_prefix.take();
- path.clear();
- }
- }
-
- Cow::Owned(path)
-}
-
-fn eat_required_whitespace(header: &str) -> Option<&str> {
- let trimmed = header.trim_ascii_start();
-
- if trimmed.len() == header.len() {
- None
- } else {
- Some(trimmed)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use indoc::indoc;
-
- #[test]
- fn parse_lines_simple() {
- let input = indoc! {"
- diff --git a/text.txt b/text.txt
- index 86c770d..a1fd855 100644
- --- a/file.txt
- +++ b/file.txt
- @@ -1,2 +1,3 @@
- context
- -deleted
- +inserted
- garbage
-
- --- b/file.txt
- +++ a/file.txt
- "};
-
- let lines = input.lines().map(DiffLine::parse).collect::<Vec<_>>();
-
- pretty_assertions::assert_eq!(
- lines,
- &[
- DiffLine::Garbage("diff --git a/text.txt b/text.txt"),
- DiffLine::Garbage("index 86c770d..a1fd855 100644"),
- DiffLine::OldPath {
- path: "file.txt".into()
- },
- DiffLine::NewPath {
- path: "file.txt".into()
- },
- DiffLine::HunkHeader(Some(HunkLocation {
- start_line_old: 0,
- count_old: 2,
- start_line_new: 0,
- count_new: 3
- })),
- DiffLine::Context("context"),
- DiffLine::Deletion("deleted"),
- DiffLine::Addition("inserted"),
- DiffLine::Garbage("garbage"),
- DiffLine::Context(""),
- DiffLine::OldPath {
- path: "b/file.txt".into()
- },
- DiffLine::NewPath {
- path: "a/file.txt".into()
- },
- ]
- );
- }
-
- #[test]
- fn file_header_extra_space() {
- let options = ["--- file", "--- file", "---\tfile"];
-
- for option in options {
- pretty_assertions::assert_eq!(
- DiffLine::parse(option),
- DiffLine::OldPath {
- path: "file".into()
- },
- "{option}",
- );
- }
- }
-
- #[test]
- fn hunk_header_extra_space() {
- let options = [
- "@@ -1,2 +1,3 @@",
- "@@ -1,2 +1,3 @@",
- "@@\t-1,2\t+1,3\t@@",
- "@@ -1,2 +1,3 @@",
- "@@ -1,2 +1,3 @@",
- "@@ -1,2 +1,3 @@",
- "@@ -1,2 +1,3 @@ garbage",
- ];
-
- for option in options {
- pretty_assertions::assert_eq!(
- DiffLine::parse(option),
- DiffLine::HunkHeader(Some(HunkLocation {
- start_line_old: 0,
- count_old: 2,
- start_line_new: 0,
- count_new: 3
- })),
- "{option}",
- );
- }
- }
-
- #[test]
- fn hunk_header_without_location() {
- pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None));
- }
-
- #[test]
- fn test_parse_path() {
- assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt");
- assert_eq!(
- parse_header_path("a/", "foo/bar/baz.txt"),
- "foo/bar/baz.txt"
- );
- assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt");
- assert_eq!(
- parse_header_path("a/", "a/foo/bar/baz.txt"),
- "foo/bar/baz.txt"
- );
-
- // Extra
- assert_eq!(
- parse_header_path("a/", "a/foo/bar/baz.txt 2025"),
- "foo/bar/baz.txt"
- );
- assert_eq!(
- parse_header_path("a/", "a/foo/bar/baz.txt\t2025"),
- "foo/bar/baz.txt"
- );
- assert_eq!(
- parse_header_path("a/", "a/foo/bar/baz.txt \""),
- "foo/bar/baz.txt"
- );
-
- // Quoted
- assert_eq!(
- parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""),
- "foo/bar/baz quox.txt"
- );
- assert_eq!(
- parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""),
- "foo/bar/baz quox.txt"
- );
- assert_eq!(
- parse_header_path("a/", "\"foo/bar/baz quox.txt\""),
- "foo/bar/baz quox.txt"
- );
- assert_eq!(parse_header_path("a/", "\"whatever ๐คท\""), "whatever ๐คท");
- assert_eq!(
- parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"),
- "foo/bar/baz quox.txt"
- );
- // unescaped quotes are dropped
- assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar");
-
- // Escaped
- assert_eq!(
- parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""),
- "foo/\"bar\"/baz.txt"
- );
- assert_eq!(
- parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""),
- "C:\\Projects\\My App\\old file.txt"
- );
- }
-}
@@ -17,5 +17,7 @@ cloud_llm_client.workspace = true
indoc.workspace = true
ordered-float.workspace = true
rustc-hash.workspace = true
+schemars.workspace = true
serde.workspace = true
+serde_json.workspace = true
strum.workspace = true
@@ -1,8 +1,9 @@
//! Zeta2 prompt planning and generation code shared with cloud.
+pub mod retrieval_prompt;
use anyhow::{Context as _, Result, anyhow};
use cloud_llm_client::predict_edits_v3::{
- self, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration,
+ self, DiffPathFmt, Excerpt, Line, Point, PromptFormat, ReferencedDeclaration,
};
use indoc::indoc;
use ordered_float::OrderedFloat;
@@ -212,7 +213,7 @@ pub fn write_codeblock<'a>(
include_line_numbers: bool,
output: &'a mut String,
) {
- writeln!(output, "`````{}", path.display()).unwrap();
+ writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
@@ -275,7 +276,7 @@ pub fn write_excerpts<'a>(
}
}
-fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
+pub fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
if events.is_empty() {
return;
};
@@ -0,0 +1,92 @@
+use anyhow::Result;
+use cloud_llm_client::predict_edits_v3::{self, Excerpt};
+use indoc::indoc;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::{fmt::Write, sync::LazyLock};
+
+use crate::{push_events, write_codeblock};
+
+pub fn build_prompt(request: predict_edits_v3::PlanContextRetrievalRequest) -> Result<String> {
+ let mut prompt = SEARCH_INSTRUCTIONS.to_string();
+
+ if !request.events.is_empty() {
+ writeln!(&mut prompt, "## User Edits\n")?;
+ push_events(&mut prompt, &request.events);
+ }
+
+ writeln!(&mut prompt, "## Excerpt around the cursor\n")?;
+ write_codeblock(
+ &request.excerpt_path,
+ &[Excerpt {
+ start_line: request.excerpt_line_range.start,
+ text: request.excerpt.into(),
+ }],
+ &[],
+ request.cursor_file_max_row,
+ true,
+ &mut prompt,
+ );
+
+ writeln!(&mut prompt, "{TOOL_USE_REMINDER}")?;
+
+ Ok(prompt)
+}
+
+/// Search for relevant code
+///
+/// For the best results, run multiple queries at once with a single invocation of this tool.
+#[derive(Clone, Deserialize, Serialize, JsonSchema)]
+pub struct SearchToolInput {
+ /// An array of queries to run for gathering context relevant to the next prediction
+ #[schemars(length(max = 5))]
+ pub queries: Box<[SearchToolQuery]>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct SearchToolQuery {
+ /// A glob pattern to match file paths in the codebase
+ pub glob: String,
+ /// A regular expression to match content within the files matched by the glob pattern
+ pub regex: String,
+}
+
+pub static TOOL_SCHEMA: LazyLock<(serde_json::Value, String)> = LazyLock::new(|| {
+ let schema = schemars::schema_for!(SearchToolInput);
+
+ let description = schema
+ .get("description")
+ .and_then(|description| description.as_str())
+ .unwrap()
+ .to_string();
+
+ (schema.into(), description)
+});
+
+pub const TOOL_NAME: &str = "search";
+
+const SEARCH_INSTRUCTIONS: &str = indoc! {r#"
+ ## Task
+
+ You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
+ that will serve as context for predicting the next required edit.
+
+ **Your task:**
+ - Analyze the user's recent edits and current cursor context
+ - Use the `search` tool to find code that may be relevant for predicting the next edit
+ - Focus on finding:
+ - Code patterns that might need similar changes based on the recent edits
+ - Functions, variables, types, and constants referenced in the current cursor context
+ - Related implementations, usages, or dependencies that may require consistent updates
+
+ **Important constraints:**
+ - This conversation has exactly 2 turns
+ - You must make ALL search queries in your first response via the `search` tool
+ - All queries will be executed in parallel and results returned together
+ - In the second turn, you will select the most relevant results via the `select` tool.
+"#};
+
+const TOOL_USE_REMINDER: &str = indoc! {"
+ --
+ Use the `search` tool now
+"};
@@ -34,7 +34,7 @@ struct CurrentCompletion {
snapshot: BufferSnapshot,
/// The edits that should be applied to transform the original text into the predicted text.
/// Each edit is a range in the buffer and the text to replace it with.
- edits: Arc<[(Range<Anchor>, String)]>,
+ edits: Arc<[(Range<Anchor>, Arc<str>)]>,
/// Preview of how the buffer will look after applying the edits.
edit_preview: EditPreview,
}
@@ -42,7 +42,7 @@ struct CurrentCompletion {
impl CurrentCompletion {
/// Attempts to adjust the edits based on changes made to the buffer since the completion was generated.
/// Returns None if the user's edits conflict with the predicted edits.
- fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
+ fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
}
@@ -281,8 +281,8 @@ impl EditPredictionProvider for CodestralCompletionProvider {
return Ok(());
}
- let edits: Arc<[(Range<Anchor>, String)]> =
- vec![(cursor_position..cursor_position, completion_text)].into();
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
+ vec![(cursor_position..cursor_position, completion_text.into())].into();
let edit_preview = buffer
.read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))?
.await;
@@ -1,4 +1,4 @@
-use std::ops::Range;
+use std::{ops::Range, sync::Arc};
use client::EditPredictionUsage;
use gpui::{App, Context, Entity, SharedString};
@@ -19,7 +19,7 @@ pub enum EditPrediction {
/// Edits within the buffer that requested the prediction
Local {
id: Option<SharedString>,
- edits: Vec<(Range<language::Anchor>, String)>,
+ edits: Vec<(Range<language::Anchor>, Arc<str>)>,
edit_preview: Option<language::EditPreview>,
},
/// Jump to a different file from the one that requested the prediction
@@ -248,8 +248,8 @@ where
pub fn interpolate_edits(
old_snapshot: &BufferSnapshot,
new_snapshot: &BufferSnapshot,
- current_edits: &[(Range<Anchor>, String)],
-) -> Option<Vec<(Range<Anchor>, String)>> {
+ current_edits: &[(Range<Anchor>, Arc<str>)],
+) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
let mut edits = Vec::new();
let mut model_edits = current_edits.iter().peekable();
@@ -274,7 +274,7 @@ pub fn interpolate_edits(
if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
if !model_suffix.is_empty() {
let anchor = old_snapshot.anchor_after(user_edit.old.end);
- edits.push((anchor..anchor, model_suffix.to_string()));
+ edits.push((anchor..anchor, model_suffix.into()));
}
model_edits.next();
@@ -2,7 +2,7 @@ use edit_prediction::EditPredictionProvider;
use gpui::{Entity, KeyBinding, Modifiers, prelude::*};
use indoc::indoc;
use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint};
-use std::ops::Range;
+use std::{ops::Range, sync::Arc};
use text::{Point, ToOffset};
use crate::{
@@ -24,7 +24,7 @@ async fn test_edit_prediction_insert(cx: &mut gpui::TestAppContext) {
assert_editor_active_edit_completion(&mut cx, |_, edits| {
assert_eq!(edits.len(), 1);
- assert_eq!(edits[0].1.as_str(), "-273.15");
+ assert_eq!(edits[0].1.as_ref(), "-273.15");
});
accept_completion(&mut cx);
@@ -46,7 +46,7 @@ async fn test_edit_prediction_modification(cx: &mut gpui::TestAppContext) {
assert_editor_active_edit_completion(&mut cx, |_, edits| {
assert_eq!(edits.len(), 1);
- assert_eq!(edits[0].1.as_str(), "3.14159");
+ assert_eq!(edits[0].1.as_ref(), "3.14159");
});
accept_completion(&mut cx);
@@ -330,7 +330,7 @@ async fn test_edit_prediction_preview_cleanup_on_toggle_off(cx: &mut gpui::TestA
fn assert_editor_active_edit_completion(
cx: &mut EditorTestContext,
- assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, String)>),
+ assert: impl FnOnce(MultiBufferSnapshot, &Vec<(Range<Anchor>, Arc<str>)>),
) {
cx.editor(|editor, _, cx| {
let completion_state = editor
@@ -616,7 +616,7 @@ pub(crate) enum EditDisplayMode {
enum EditPrediction {
Edit {
- edits: Vec<(Range<Anchor>, String)>,
+ edits: Vec<(Range<Anchor>, Arc<str>)>,
edit_preview: Option<EditPreview>,
display_mode: EditDisplayMode,
snapshot: BufferSnapshot,
@@ -7960,7 +7960,7 @@ impl Editor {
let inlay = Inlay::edit_prediction(
post_inc(&mut self.next_inlay_id),
range.start,
- new_text.as_str(),
+ new_text.as_ref(),
);
inlay_ids.push(inlay.id);
inlays.push(inlay);
@@ -8982,7 +8982,7 @@ impl Editor {
newest_selection_head: Option<DisplayPoint>,
editor_width: Pixels,
style: &EditorStyle,
- edits: &Vec<(Range<Anchor>, String)>,
+ edits: &Vec<(Range<Anchor>, Arc<str>)>,
edit_preview: &Option<language::EditPreview>,
snapshot: &language::BufferSnapshot,
window: &mut Window,
@@ -24382,25 +24382,20 @@ impl InvalidationRegion for SnippetState {
fn edit_prediction_edit_text(
current_snapshot: &BufferSnapshot,
- edits: &[(Range<Anchor>, String)],
+ edits: &[(Range<Anchor>, impl AsRef<str>)],
edit_preview: &EditPreview,
include_deletions: bool,
cx: &App,
) -> HighlightedText {
let edits = edits
.iter()
- .map(|(anchor, text)| {
- (
- anchor.start.text_anchor..anchor.end.text_anchor,
- text.clone(),
- )
- })
+ .map(|(anchor, text)| (anchor.start.text_anchor..anchor.end.text_anchor, text))
.collect::<Vec<_>>();
edit_preview.highlight_edits(current_snapshot, &edits, include_deletions, cx)
}
-fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, String)], cx: &App) -> HighlightedText {
+fn edit_prediction_fallback_text(edits: &[(Range<Anchor>, Arc<str>)], cx: &App) -> HighlightedText {
// Fallback for providers that don't provide edit_preview (like Copilot/Supermaven)
// Just show the raw edit text with basic styling
let mut text = String::new();
@@ -24793,7 +24788,7 @@ impl Focusable for BreakpointPromptEditor {
}
fn all_edits_insertions_or_deletions(
- edits: &Vec<(Range<Anchor>, String)>,
+ edits: &Vec<(Range<Anchor>, Arc<str>)>,
snapshot: &MultiBufferSnapshot,
) -> bool {
let mut all_insertions = true;
@@ -22915,7 +22915,7 @@ async fn assert_highlighted_edits(
let text_anchor_edits = edits
.clone()
.into_iter()
- .map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit))
+ .map(|(range, edit)| (range.start.text_anchor..range.end.text_anchor, edit.into()))
.collect::<Vec<_>>();
let edit_preview = window
@@ -720,7 +720,7 @@ impl EditPreview {
pub fn highlight_edits(
&self,
current_snapshot: &BufferSnapshot,
- edits: &[(Range<Anchor>, String)],
+ edits: &[(Range<Anchor>, impl AsRef<str>)],
include_deletions: bool,
cx: &App,
) -> HighlightedText {
@@ -747,7 +747,8 @@ impl EditPreview {
.end
.bias_right(&self.old_snapshot)
.to_offset(&self.applied_edits_snapshot);
- let edit_start_in_preview_snapshot = edit_new_end_in_preview_snapshot - edit_text.len();
+ let edit_start_in_preview_snapshot =
+ edit_new_end_in_preview_snapshot - edit_text.as_ref().len();
let unchanged_range_in_preview_snapshot =
offset_in_preview_snapshot..edit_start_in_preview_snapshot;
@@ -772,7 +773,7 @@ impl EditPreview {
);
}
- if !edit_text.is_empty() {
+ if !edit_text.as_ref().is_empty() {
highlighted_text.add_text_from_buffer_range(
edit_start_in_preview_snapshot..edit_new_end_in_preview_snapshot,
&self.applied_edits_snapshot,
@@ -796,7 +797,7 @@ impl EditPreview {
highlighted_text.build()
}
- fn compute_visible_range(&self, edits: &[(Range<Anchor>, String)]) -> Option<Range<usize>> {
+ fn compute_visible_range<T>(&self, edits: &[(Range<Anchor>, T)]) -> Option<Range<usize>> {
let (first, _) = edits.first()?;
let (last, _) = edits.last()?;
@@ -1130,7 +1131,7 @@ impl Buffer {
pub fn preview_edits(
&self,
- edits: Arc<[(Range<Anchor>, String)]>,
+ edits: Arc<[(Range<Anchor>, Arc<str>)]>,
cx: &App,
) -> Task<EditPreview> {
let registry = self.language_registry();
@@ -3120,15 +3120,13 @@ async fn test_preview_edits(cx: &mut TestAppContext) {
.map(|(range, text)| {
(
buffer.anchor_before(range.start)..buffer.anchor_after(range.end),
- text.to_string(),
+ text.into(),
)
})
- .collect::<Vec<_>>()
+ .collect::<Arc<[_]>>()
});
let edit_preview = buffer
- .read_with(cx, |buffer, cx| {
- buffer.preview_edits(edits.clone().into(), cx)
- })
+ .read_with(cx, |buffer, cx| buffer.preview_edits(edits.clone(), cx))
.await;
let highlighted_edits = cx.read(|cx| {
edit_preview.highlight_edits(&buffer.read(cx).snapshot(), &edits, include_deletions, cx)
@@ -293,7 +293,7 @@ pub struct FunctionDefinition {
pub parameters: Option<Value>,
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
Assistant {
@@ -366,25 +366,42 @@ pub struct ImageUrl {
pub detail: Option<String>,
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
pub id: String,
#[serde(flatten)]
pub content: ToolCallContent,
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
Function { function: FunctionContent },
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
pub name: String,
pub arguments: String,
}
+#[derive(Clone, Serialize, Deserialize, Debug)]
+pub struct Response {
+ pub id: String,
+ pub object: String,
+ pub created: u64,
+ pub model: String,
+ pub choices: Vec<Choice>,
+ pub usage: Usage,
+}
+
+#[derive(Clone, Serialize, Deserialize, Debug)]
+pub struct Choice {
+ pub index: u32,
+ pub message: RequestMessage,
+ pub finish_reason: Option<String>,
+}
+
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
pub role: Option<Role>,
@@ -410,7 +427,7 @@ pub struct FunctionChunk {
pub arguments: Option<String>,
}
-#[derive(Serialize, Deserialize, Debug)]
+#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u64,
pub completion_tokens: u64,
@@ -7,6 +7,7 @@ use language::{Anchor, Buffer, BufferSnapshot};
use std::{
ops::{AddAssign, Range},
path::Path,
+ sync::Arc,
time::Duration,
};
use text::{ToOffset, ToPoint};
@@ -51,7 +52,7 @@ fn completion_from_diff(
) -> EditPrediction {
let buffer_text = snapshot.text_for_range(delete_range).collect::<String>();
- let mut edits: Vec<(Range<language::Anchor>, String)> = Vec::new();
+ let mut edits: Vec<(Range<language::Anchor>, Arc<str>)> = Vec::new();
let completion_graphemes: Vec<&str> = completion_text.graphemes(true).collect();
let buffer_graphemes: Vec<&str> = buffer_text.graphemes(true).collect();
@@ -70,7 +71,10 @@ fn completion_from_diff(
if k != 0 {
let offset = snapshot.anchor_after(offset);
// the range from the current position to item is an inlay.
- let edit = (offset..offset, completion_graphemes[i..i + k].join(""));
+ let edit = (
+ offset..offset,
+ completion_graphemes[i..i + k].join("").into(),
+ );
edits.push(edit);
}
i += k + 1;
@@ -90,7 +94,7 @@ fn completion_from_diff(
// there is leftover completion text, so drop it as an inlay.
let edit_range = offset..offset;
let edit_text = completion_graphemes[i..].join("");
- edits.push((edit_range, edit_text));
+ edits.push((edit_range, edit_text.into()));
}
EditPrediction::Local {
@@ -133,7 +133,7 @@ pub struct EditPrediction {
path: Arc<Path>,
excerpt_range: Range<usize>,
cursor_offset: usize,
- edits: Arc<[(Range<Anchor>, String)]>,
+ edits: Arc<[(Range<Anchor>, Arc<str>)]>,
snapshot: BufferSnapshot,
edit_preview: EditPreview,
input_outline: Arc<str>,
@@ -150,7 +150,7 @@ impl EditPrediction {
.duration_since(self.buffer_snapshotted_at)
}
- fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
+ fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
edit_prediction::interpolate_edits(&self.snapshot, new_snapshot, &self.edits)
}
}
@@ -711,7 +711,7 @@ impl Zeta {
cx.spawn(async move |cx| {
let output_excerpt: Arc<str> = output_excerpt.into();
- let edits: Arc<[(Range<Anchor>, String)]> = cx
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
.background_spawn({
let output_excerpt = output_excerpt.clone();
let editable_range = editable_range.clone();
@@ -725,7 +725,7 @@ impl Zeta {
let edits = edits.clone();
move |buffer, cx| {
let new_snapshot = buffer.snapshot();
- let edits: Arc<[(Range<Anchor>, String)]> =
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> =
edit_prediction::interpolate_edits(&snapshot, &new_snapshot, &edits)?
.into();
Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
@@ -759,7 +759,7 @@ impl Zeta {
output_excerpt: Arc<str>,
editable_range: Range<usize>,
snapshot: &BufferSnapshot,
- ) -> Result<Vec<(Range<Anchor>, String)>> {
+ ) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
let content = output_excerpt.replace(CURSOR_MARKER, "");
let start_markers = content
@@ -817,7 +817,7 @@ impl Zeta {
new_text: &str,
offset: usize,
snapshot: &BufferSnapshot,
- ) -> Vec<(Range<Anchor>, String)> {
+ ) -> Vec<(Range<Anchor>, Arc<str>)> {
text_diff(&old_text, new_text)
.into_iter()
.map(|(mut old_range, new_text)| {
@@ -836,7 +836,7 @@ impl Zeta {
);
old_range.end = old_range.end.saturating_sub(suffix_len);
- let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
+ let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
let range = if old_range.is_empty() {
let anchor = snapshot.anchor_after(old_range.start);
anchor..anchor
@@ -1183,7 +1183,7 @@ impl CurrentEditPrediction {
if old_edits.len() == 1 && new_edits.len() == 1 {
let (old_range, old_text) = &old_edits[0];
let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text)
+ new_range == old_range && new_text.starts_with(old_text.as_ref())
} else {
true
}
@@ -1599,13 +1599,8 @@ mod tests {
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
- let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
- to_completion_edits(
- [(2..5, "REM".to_string()), (9..11, "".to_string())],
- &buffer,
- cx,
- )
- .into()
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
+ to_completion_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
});
let edit_preview = cx
@@ -1635,7 +1630,7 @@ mod tests {
&buffer,
cx
),
- vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+ vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
@@ -1645,7 +1640,7 @@ mod tests {
&buffer,
cx
),
- vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
+ vec![(2..2, "REM".into()), (6..8, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.undo(cx));
@@ -1655,7 +1650,7 @@ mod tests {
&buffer,
cx
),
- vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+ vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
@@ -1665,7 +1660,7 @@ mod tests {
&buffer,
cx
),
- vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
+ vec![(3..3, "EM".into()), (7..9, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
@@ -1675,7 +1670,7 @@ mod tests {
&buffer,
cx
),
- vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+ vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
@@ -1685,7 +1680,7 @@ mod tests {
&buffer,
cx
),
- vec![(9..11, "".to_string())]
+ vec![(9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
@@ -1695,7 +1690,7 @@ mod tests {
&buffer,
cx
),
- vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+ vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
@@ -1705,7 +1700,7 @@ mod tests {
&buffer,
cx
),
- vec![(4..4, "M".to_string())]
+ vec![(4..4, "M".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
@@ -2211,10 +2206,10 @@ mod tests {
}
fn to_completion_edits(
- iterator: impl IntoIterator<Item = (Range<usize>, String)>,
+ iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
buffer: &Entity<Buffer>,
cx: &App,
- ) -> Vec<(Range<Anchor>, String)> {
+ ) -> Vec<(Range<Anchor>, Arc<str>)> {
let buffer = buffer.read(cx);
iterator
.into_iter()
@@ -2228,10 +2223,10 @@ mod tests {
}
fn from_completion_edits(
- editor_edits: &[(Range<Anchor>, String)],
+ editor_edits: &[(Range<Anchor>, Arc<str>)],
buffer: &Entity<Buffer>,
cx: &App,
- ) -> Vec<(Range<usize>, String)> {
+ ) -> Vec<(Range<usize>, Arc<str>)> {
let buffer = buffer.read(cx);
editor_edits
.iter()
@@ -28,9 +28,9 @@ indoc.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
+open_ai.workspace = true
project.workspace = true
release_channel.workspace = true
-schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
@@ -50,3 +50,4 @@ language_model = { workspace = true, features = ["test-support"] }
pretty_assertions.workspace = true
project = { workspace = true, features = ["test-support"] }
settings = { workspace = true, features = ["test-support"] }
+zlog.workspace = true
@@ -1,17 +1,11 @@
-use std::{borrow::Cow, ops::Range, path::Path, sync::Arc};
-
-use anyhow::Context as _;
-use cloud_llm_client::predict_edits_v3;
-use gpui::{App, AsyncApp, Entity};
-use language::{
- Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot, text_diff,
-};
-use project::Project;
-use util::ResultExt;
+use std::{ops::Range, sync::Arc};
+
+use gpui::{AsyncApp, Entity};
+use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot};
use uuid::Uuid;
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct EditPredictionId(Uuid);
+pub struct EditPredictionId(pub Uuid);
impl Into<Uuid> for EditPredictionId {
fn into(self) -> Uuid {
@@ -34,8 +28,7 @@ impl std::fmt::Display for EditPredictionId {
#[derive(Clone)]
pub struct EditPrediction {
pub id: EditPredictionId,
- pub path: Arc<Path>,
- pub edits: Arc<[(Range<Anchor>, String)]>,
+ pub edits: Arc<[(Range<Anchor>, Arc<str>)]>,
pub snapshot: BufferSnapshot,
pub edit_preview: EditPreview,
// We keep a reference to the buffer so that we do not need to reload it from disk when applying the prediction.
@@ -43,90 +36,43 @@ pub struct EditPrediction {
}
impl EditPrediction {
- pub async fn from_response(
- response: predict_edits_v3::PredictEditsResponse,
- active_buffer_old_snapshot: &TextBufferSnapshot,
- active_buffer: &Entity<Buffer>,
- project: &Entity<Project>,
+ pub async fn new(
+ id: EditPredictionId,
+ edited_buffer: &Entity<Buffer>,
+ edited_buffer_snapshot: &BufferSnapshot,
+ edits: Vec<(Range<Anchor>, Arc<str>)>,
cx: &mut AsyncApp,
) -> Option<Self> {
- // TODO only allow cloud to return one path
- let Some(path) = response.edits.first().map(|e| e.path.clone()) else {
- return None;
- };
+ let (edits, snapshot, edit_preview_task) = edited_buffer
+ .read_with(cx, |buffer, cx| {
+ let new_snapshot = buffer.snapshot();
+ let edits: Arc<[_]> =
+ interpolate_edits(&edited_buffer_snapshot, &new_snapshot, edits.into())?.into();
- let is_same_path = active_buffer
- .read_with(cx, |buffer, cx| buffer_path_eq(buffer, &path, cx))
- .ok()?;
-
- let (buffer, edits, snapshot, edit_preview_task) = if is_same_path {
- active_buffer
- .read_with(cx, |buffer, cx| {
- let new_snapshot = buffer.snapshot();
- let edits = edits_from_response(&response.edits, &active_buffer_old_snapshot);
- let edits: Arc<[_]> =
- interpolate_edits(active_buffer_old_snapshot, &new_snapshot, edits)?.into();
-
- Some((
- active_buffer.clone(),
- edits.clone(),
- new_snapshot,
- buffer.preview_edits(edits, cx),
- ))
- })
- .ok()??
- } else {
- let buffer_handle = project
- .update(cx, |project, cx| {
- let project_path = project
- .find_project_path(&path, cx)
- .context("Failed to find project path for zeta edit")?;
- anyhow::Ok(project.open_buffer(project_path, cx))
- })
- .ok()?
- .log_err()?
- .await
- .context("Failed to open buffer for zeta edit")
- .log_err()?;
-
- buffer_handle
- .read_with(cx, |buffer, cx| {
- let snapshot = buffer.snapshot();
- let edits = edits_from_response(&response.edits, &snapshot);
- if edits.is_empty() {
- return None;
- }
- Some((
- buffer_handle.clone(),
- edits.clone(),
- snapshot,
- buffer.preview_edits(edits, cx),
- ))
- })
- .ok()??
- };
+ Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
+ })
+ .ok()??;
let edit_preview = edit_preview_task.await;
Some(EditPrediction {
- id: EditPredictionId(response.request_id),
- path,
+ id,
edits,
snapshot,
edit_preview,
- buffer,
+ buffer: edited_buffer.clone(),
})
}
pub fn interpolate(
&self,
new_snapshot: &TextBufferSnapshot,
- ) -> Option<Vec<(Range<Anchor>, String)>> {
+ ) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
interpolate_edits(&self.snapshot, new_snapshot, self.edits.clone())
}
- pub fn targets_buffer(&self, buffer: &Buffer, cx: &App) -> bool {
- buffer_path_eq(buffer, &self.path, cx)
+ pub fn targets_buffer(&self, buffer: &Buffer) -> bool {
+ self.snapshot.remote_id() == buffer.remote_id()
}
}
@@ -134,21 +80,16 @@ impl std::fmt::Debug for EditPrediction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EditPrediction")
.field("id", &self.id)
- .field("path", &self.path)
.field("edits", &self.edits)
.finish()
}
}
-pub fn buffer_path_eq(buffer: &Buffer, path: &Path, cx: &App) -> bool {
- buffer.file().map(|p| p.full_path(cx)).as_deref() == Some(path)
-}
-
pub fn interpolate_edits(
old_snapshot: &TextBufferSnapshot,
new_snapshot: &TextBufferSnapshot,
- current_edits: Arc<[(Range<Anchor>, String)]>,
-) -> Option<Vec<(Range<Anchor>, String)>> {
+ current_edits: Arc<[(Range<Anchor>, Arc<str>)]>,
+) -> Option<Vec<(Range<Anchor>, Arc<str>)>> {
let mut edits = Vec::new();
let mut model_edits = current_edits.iter().peekable();
@@ -173,7 +114,7 @@ pub fn interpolate_edits(
if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
if !model_suffix.is_empty() {
let anchor = old_snapshot.anchor_after(user_edit.old.end);
- edits.push((anchor..anchor, model_suffix.to_string()));
+ edits.push((anchor..anchor, model_suffix.into()));
}
model_edits.next();
@@ -190,135 +131,17 @@ pub fn interpolate_edits(
if edits.is_empty() { None } else { Some(edits) }
}
-pub fn line_range_to_point_range(range: Range<predict_edits_v3::Line>) -> Range<language::Point> {
- language::Point::new(range.start.0, 0)..language::Point::new(range.end.0, 0)
-}
-
-fn edits_from_response(
- edits: &[predict_edits_v3::Edit],
- snapshot: &TextBufferSnapshot,
-) -> Arc<[(Range<Anchor>, String)]> {
- edits
- .iter()
- .flat_map(|edit| {
- let point_range = line_range_to_point_range(edit.range.clone());
- let offset = point_range.to_offset(snapshot).start;
- let old_text = snapshot.text_for_range(point_range);
-
- excerpt_edits_from_response(
- old_text.collect::<Cow<str>>(),
- &edit.content,
- offset,
- &snapshot,
- )
- })
- .collect::<Vec<_>>()
- .into()
-}
-
-fn excerpt_edits_from_response(
- old_text: Cow<str>,
- new_text: &str,
- offset: usize,
- snapshot: &TextBufferSnapshot,
-) -> impl Iterator<Item = (Range<Anchor>, String)> {
- text_diff(&old_text, new_text)
- .into_iter()
- .map(move |(mut old_range, new_text)| {
- old_range.start += offset;
- old_range.end += offset;
-
- let prefix_len = common_prefix(
- snapshot.chars_for_range(old_range.clone()),
- new_text.chars(),
- );
- old_range.start += prefix_len;
-
- let suffix_len = common_prefix(
- snapshot.reversed_chars_for_range(old_range.clone()),
- new_text[prefix_len..].chars().rev(),
- );
- old_range.end = old_range.end.saturating_sub(suffix_len);
-
- let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
- let range = if old_range.is_empty() {
- let anchor = snapshot.anchor_after(old_range.start);
- anchor..anchor
- } else {
- snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
- };
- (range, new_text)
- })
-}
-
-fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
- a.zip(b)
- .take_while(|(a, b)| a == b)
- .map(|(a, _)| a.len_utf8())
- .sum()
-}
-
#[cfg(test)]
mod tests {
- use std::path::PathBuf;
-
use super::*;
- use cloud_llm_client::predict_edits_v3;
- use edit_prediction_context::Line;
use gpui::{App, Entity, TestAppContext, prelude::*};
- use indoc::indoc;
use language::{Buffer, ToOffset as _};
- #[gpui::test]
- async fn test_compute_edits(cx: &mut TestAppContext) {
- let old = indoc! {r#"
- fn main() {
- let args =
- println!("{}", args[1])
- }
- "#};
-
- let new = indoc! {r#"
- fn main() {
- let args = std::env::args();
- println!("{}", args[1]);
- }
- "#};
-
- let buffer = cx.new(|cx| Buffer::local(old, cx));
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
-
- // TODO cover more cases when multi-file is supported
- let big_edits = vec![predict_edits_v3::Edit {
- path: PathBuf::from("test.txt").into(),
- range: Line(0)..Line(old.lines().count() as u32),
- content: new.into(),
- }];
-
- let edits = edits_from_response(&big_edits, &snapshot);
- assert_eq!(edits.len(), 2);
- assert_eq!(
- edits[0].0.to_point(&snapshot).start,
- language::Point::new(1, 14)
- );
- assert_eq!(edits[0].1, " std::env::args();");
- assert_eq!(
- edits[1].0.to_point(&snapshot).start,
- language::Point::new(2, 27)
- );
- assert_eq!(edits[1].1, ";");
- }
-
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
- let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
- to_prediction_edits(
- [(2..5, "REM".to_string()), (9..11, "".to_string())],
- &buffer,
- cx,
- )
- .into()
+ let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx.update(|cx| {
+ to_prediction_edits([(2..5, "REM".into()), (9..11, "".into())], &buffer, cx).into()
});
let edit_preview = cx
@@ -329,7 +152,6 @@ mod tests {
id: EditPredictionId(Uuid::new_v4()),
edits,
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
- path: Path::new("test.txt").into(),
buffer: buffer.clone(),
edit_preview,
};
@@ -341,7 +163,7 @@ mod tests {
&buffer,
cx
),
- vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+ vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
@@ -351,7 +173,7 @@ mod tests {
&buffer,
cx
),
- vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
+ vec![(2..2, "REM".into()), (6..8, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.undo(cx));
@@ -361,7 +183,7 @@ mod tests {
&buffer,
cx
),
- vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
+ vec![(2..5, "REM".into()), (9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
@@ -371,7 +193,7 @@ mod tests {
&buffer,
cx
),
- vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
+ vec![(3..3, "EM".into()), (7..9, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
@@ -381,7 +203,7 @@ mod tests {
&buffer,
cx
),
- vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+ vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
@@ -391,7 +213,7 @@ mod tests {
&buffer,
cx
),
- vec![(9..11, "".to_string())]
+ vec![(9..11, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
@@ -401,7 +223,7 @@ mod tests {
&buffer,
cx
),
- vec![(4..4, "M".to_string()), (8..10, "".to_string())]
+ vec![(4..4, "M".into()), (8..10, "".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
@@ -411,7 +233,7 @@ mod tests {
&buffer,
cx
),
- vec![(4..4, "M".to_string())]
+ vec![(4..4, "M".into())]
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
@@ -420,10 +242,10 @@ mod tests {
}
fn to_prediction_edits(
- iterator: impl IntoIterator<Item = (Range<usize>, String)>,
+ iterator: impl IntoIterator<Item = (Range<usize>, Arc<str>)>,
buffer: &Entity<Buffer>,
cx: &App,
- ) -> Vec<(Range<Anchor>, String)> {
+ ) -> Vec<(Range<Anchor>, Arc<str>)> {
let buffer = buffer.read(cx);
iterator
.into_iter()
@@ -437,10 +259,10 @@ mod tests {
}
fn from_prediction_edits(
- editor_edits: &[(Range<Anchor>, String)],
+ editor_edits: &[(Range<Anchor>, Arc<str>)],
buffer: &Entity<Buffer>,
cx: &App,
- ) -> Vec<(Range<usize>, String)> {
+ ) -> Vec<(Range<usize>, Arc<str>)> {
let buffer = buffer.read(cx);
editor_edits
.iter()
@@ -1,717 +0,0 @@
-use std::{
- cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
-};
-
-use crate::{
- ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
- ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
-};
-use anyhow::{Result, anyhow};
-use cloud_zeta2_prompt::write_codeblock;
-use collections::HashMap;
-use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
-use futures::{
- StreamExt,
- channel::mpsc::{self, UnboundedSender},
- stream::BoxStream,
-};
-use gpui::{App, AppContext, AsyncApp, Entity, Task};
-use indoc::indoc;
-use language::{
- Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
-};
-use language_model::{
- LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
- LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
- LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
- LanguageModelToolUse, MessageContent, Role,
-};
-use project::{
- Project, WorktreeSettings,
- search::{SearchQuery, SearchResult},
-};
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use util::{
- ResultExt as _,
- paths::{PathMatcher, PathStyle},
-};
-use workspace::item::Settings as _;
-
-const SEARCH_PROMPT: &str = indoc! {r#"
- ## Task
-
- You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
- that will serve as context for predicting the next required edit.
-
- **Your task:**
- - Analyze the user's recent edits and current cursor context
- - Use the `search` tool to find code that may be relevant for predicting the next edit
- - Focus on finding:
- - Code patterns that might need similar changes based on the recent edits
- - Functions, variables, types, and constants referenced in the current cursor context
- - Related implementations, usages, or dependencies that may require consistent updates
-
- **Important constraints:**
- - This conversation has exactly 2 turns
- - You must make ALL search queries in your first response via the `search` tool
- - All queries will be executed in parallel and results returned together
- - In the second turn, you will select the most relevant results via the `select` tool.
-
- ## User Edits
-
- {edits}
-
- ## Current cursor context
-
- `````{current_file_path}
- {cursor_excerpt}
- `````
-
- --
- Use the `search` tool now
-"#};
-
-const SEARCH_TOOL_NAME: &str = "search";
-
-/// Search for relevant code
-///
-/// For the best results, run multiple queries at once with a single invocation of this tool.
-#[derive(Clone, Deserialize, Serialize, JsonSchema)]
-pub struct SearchToolInput {
- /// An array of queries to run for gathering context relevant to the next prediction
- #[schemars(length(max = 5))]
- pub queries: Box<[SearchToolQuery]>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
-pub struct SearchToolQuery {
- /// A glob pattern to match file paths in the codebase
- pub glob: String,
- /// A regular expression to match content within the files matched by the glob pattern
- pub regex: String,
-}
-
-const RESULTS_MESSAGE: &str = indoc! {"
- Here are the results of your queries combined and grouped by file:
-
-"};
-
-const SELECT_TOOL_NAME: &str = "select";
-
-const SELECT_PROMPT: &str = indoc! {"
- Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
- Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
- Include up to 200 lines in total.
-"};
-
-/// Select line ranges from search results
-#[derive(Deserialize, JsonSchema)]
-struct SelectToolInput {
- /// The line ranges to select from search results.
- ranges: Vec<SelectLineRange>,
-}
-
-/// A specific line range to select from a file
-#[derive(Debug, Deserialize, JsonSchema)]
-struct SelectLineRange {
- /// The file path containing the lines to select
- /// Exactly as it appears in the search result codeblocks.
- path: PathBuf,
- /// The starting line number (1-based)
- #[schemars(range(min = 1))]
- start_line: u32,
- /// The ending line number (1-based, inclusive)
- #[schemars(range(min = 1))]
- end_line: u32,
-}
-
-#[derive(Debug, Clone, PartialEq)]
-pub struct LlmContextOptions {
- pub excerpt: EditPredictionExcerptOptions,
-}
-
-pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
-
-pub fn find_related_excerpts(
- buffer: Entity<language::Buffer>,
- cursor_position: Anchor,
- project: &Entity<Project>,
- mut edit_history_unified_diff: String,
- options: &LlmContextOptions,
- debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
- cx: &App,
-) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
- let language_model_registry = LanguageModelRegistry::global(cx);
- let Some(model) = language_model_registry
- .read(cx)
- .available_models(cx)
- .find(|model| {
- model.provider_id() == MODEL_PROVIDER_ID
- && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
- // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
- // model.provider_id() == LanguageModelProviderId::new("ollama")
- // && model.id() == LanguageModelId("gpt-oss:20b".into())
- })
- else {
- return Task::ready(Err(anyhow!("could not find context model")));
- };
-
- if edit_history_unified_diff.is_empty() {
- edit_history_unified_diff.push_str("(No user edits yet)");
- }
-
- // TODO [zeta2] include breadcrumbs?
- let snapshot = buffer.read(cx).snapshot();
- let cursor_point = cursor_position.to_point(&snapshot);
- let Some(cursor_excerpt) =
- EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
- else {
- return Task::ready(Ok(HashMap::default()));
- };
-
- let current_file_path = snapshot
- .file()
- .map(|f| f.full_path(cx).display().to_string())
- .unwrap_or_else(|| "untitled".to_string());
-
- let prompt = SEARCH_PROMPT
- .replace("{edits}", &edit_history_unified_diff)
- .replace("{current_file_path}", &current_file_path)
- .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
- ZetaContextRetrievalStartedDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- search_prompt: prompt.clone(),
- },
- ))
- .ok();
- }
-
- let path_style = project.read(cx).path_style(cx);
-
- let exclude_matcher = {
- let global_settings = WorktreeSettings::get_global(cx);
- let exclude_patterns = global_settings
- .file_scan_exclusions
- .sources()
- .iter()
- .chain(global_settings.private_files.sources().iter());
-
- match PathMatcher::new(exclude_patterns, path_style) {
- Ok(matcher) => matcher,
- Err(err) => {
- return Task::ready(Err(anyhow!(err)));
- }
- }
- };
-
- let project = project.clone();
- cx.spawn(async move |cx| {
- let initial_prompt_message = LanguageModelRequestMessage {
- role: Role::User,
- content: vec![prompt.into()],
- cache: false,
- };
-
- let mut search_stream = request_tool_call::<SearchToolInput>(
- vec![initial_prompt_message.clone()],
- SEARCH_TOOL_NAME,
- &model,
- cx,
- )
- .await?;
-
- let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
- select_request_messages.push(initial_prompt_message);
-
- let mut regex_by_glob: HashMap<String, String> = HashMap::default();
- let mut search_calls = Vec::new();
-
- while let Some(event) = search_stream.next().await {
- match event? {
- LanguageModelCompletionEvent::ToolUse(tool_use) => {
- if !tool_use.is_input_complete {
- continue;
- }
-
- if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
- let input =
- serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
-
- for query in input.queries {
- let regex = regex_by_glob.entry(query.glob).or_default();
- if !regex.is_empty() {
- regex.push('|');
- }
- regex.push_str(&query.regex);
- }
-
- search_calls.push(tool_use);
- } else {
- log::warn!(
- "context gathering model tried to use unknown tool: {}",
- tool_use.name
- );
- }
- }
- LanguageModelCompletionEvent::Text(txt) => {
- if let Some(LanguageModelRequestMessage {
- role: Role::Assistant,
- content,
- ..
- }) = select_request_messages.last_mut()
- {
- if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
- existing_text.push_str(&txt);
- } else {
- content.push(MessageContent::Text(txt));
- }
- } else {
- select_request_messages.push(LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::Text(txt)],
- cache: false,
- });
- }
- }
- LanguageModelCompletionEvent::Thinking { text, signature } => {
- if let Some(LanguageModelRequestMessage {
- role: Role::Assistant,
- content,
- ..
- }) = select_request_messages.last_mut()
- {
- if let Some(MessageContent::Thinking {
- text: existing_text,
- signature: existing_signature,
- }) = content.last_mut()
- {
- existing_text.push_str(&text);
- *existing_signature = signature;
- } else {
- content.push(MessageContent::Thinking { text, signature });
- }
- } else {
- select_request_messages.push(LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::Thinking { text, signature }],
- cache: false,
- });
- }
- }
- LanguageModelCompletionEvent::RedactedThinking { data } => {
- if let Some(LanguageModelRequestMessage {
- role: Role::Assistant,
- content,
- ..
- }) = select_request_messages.last_mut()
- {
- if let Some(MessageContent::RedactedThinking(existing_data)) =
- content.last_mut()
- {
- existing_data.push_str(&data);
- } else {
- content.push(MessageContent::RedactedThinking(data));
- }
- } else {
- select_request_messages.push(LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::RedactedThinking(data)],
- cache: false,
- });
- }
- }
- ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
- log::error!("{ev:?}");
- }
- ev => {
- log::trace!("context search event: {ev:?}")
- }
- }
- }
-
- let search_tool_use = if search_calls.is_empty() {
- log::warn!("context model ran 0 searches");
- return anyhow::Ok(Default::default());
- } else if search_calls.len() == 1 {
- search_calls.swap_remove(0)
- } else {
- // In theory, the model could perform multiple search calls
- // Dealing with them separately is not worth it when it doesn't happen in practice.
- // If it were to happen, here we would combine them into one.
- // The second request doesn't need to know it was actually two different calls ;)
- let input = serde_json::to_value(&SearchToolInput {
- queries: regex_by_glob
- .iter()
- .map(|(glob, regex)| SearchToolQuery {
- glob: glob.clone(),
- regex: regex.clone(),
- })
- .collect(),
- })
- .unwrap_or_default();
-
- LanguageModelToolUse {
- id: search_calls.swap_remove(0).id,
- name: SELECT_TOOL_NAME.into(),
- raw_input: serde_json::to_string(&input).unwrap_or_default(),
- input,
- is_input_complete: true,
- }
- };
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
- ZetaSearchQueryDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- queries: regex_by_glob
- .iter()
- .map(|(glob, regex)| SearchToolQuery {
- glob: glob.clone(),
- regex: regex.clone(),
- })
- .collect(),
- },
- ))
- .ok();
- }
-
- let (results_tx, mut results_rx) = mpsc::unbounded();
-
- for (glob, regex) in regex_by_glob {
- let exclude_matcher = exclude_matcher.clone();
- let results_tx = results_tx.clone();
- let project = project.clone();
- cx.spawn(async move |cx| {
- run_query(
- &glob,
- &regex,
- results_tx.clone(),
- path_style,
- exclude_matcher,
- &project,
- cx,
- )
- .await
- .log_err();
- })
- .detach()
- }
- drop(results_tx);
-
- struct ResultBuffer {
- buffer: Entity<Buffer>,
- snapshot: TextBufferSnapshot,
- }
-
- let (result_buffers_by_path, merged_result) = cx
- .background_spawn(async move {
- let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
- HashMap::default();
-
- while let Some((buffer, matched)) = results_rx.next().await {
- match excerpts_by_buffer.entry(buffer) {
- Entry::Occupied(mut entry) => {
- let entry = entry.get_mut();
- entry.full_path = matched.full_path;
- entry.snapshot = matched.snapshot;
- entry.line_ranges.extend(matched.line_ranges);
- }
- Entry::Vacant(entry) => {
- entry.insert(matched);
- }
- }
- }
-
- let mut result_buffers_by_path = HashMap::default();
- let mut merged_result = RESULTS_MESSAGE.to_string();
-
- for (buffer, mut matched) in excerpts_by_buffer {
- matched
- .line_ranges
- .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
-
- write_codeblock(
- &matched.full_path,
- merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
- &[],
- Line(matched.snapshot.max_point().row),
- true,
- &mut merged_result,
- );
-
- result_buffers_by_path.insert(
- matched.full_path,
- ResultBuffer {
- buffer,
- snapshot: matched.snapshot.text,
- },
- );
- }
-
- (result_buffers_by_path, merged_result)
- })
- .await;
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
- ZetaContextRetrievalDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- },
- ))
- .ok();
- }
-
- let tool_result = LanguageModelToolResult {
- tool_use_id: search_tool_use.id.clone(),
- tool_name: SEARCH_TOOL_NAME.into(),
- is_error: false,
- content: merged_result.into(),
- output: None,
- };
-
- select_request_messages.extend([
- LanguageModelRequestMessage {
- role: Role::Assistant,
- content: vec![MessageContent::ToolUse(search_tool_use)],
- cache: false,
- },
- LanguageModelRequestMessage {
- role: Role::User,
- content: vec![MessageContent::ToolResult(tool_result)],
- cache: false,
- },
- ]);
-
- if result_buffers_by_path.is_empty() {
- log::trace!("context gathering queries produced no results");
- return anyhow::Ok(HashMap::default());
- }
-
- select_request_messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: vec![SELECT_PROMPT.into()],
- cache: false,
- });
-
- let mut select_stream = request_tool_call::<SelectToolInput>(
- select_request_messages,
- SELECT_TOOL_NAME,
- &model,
- cx,
- )
- .await?;
-
- cx.background_spawn(async move {
- let mut selected_ranges = Vec::new();
-
- while let Some(event) = select_stream.next().await {
- match event? {
- LanguageModelCompletionEvent::ToolUse(tool_use) => {
- if !tool_use.is_input_complete {
- continue;
- }
-
- if tool_use.name.as_ref() == SELECT_TOOL_NAME {
- let call =
- serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
- selected_ranges.extend(call.ranges);
- } else {
- log::warn!(
- "context gathering model tried to use unknown tool: {}",
- tool_use.name
- );
- }
- }
- ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
- log::error!("{ev:?}");
- }
- ev => {
- log::trace!("context select event: {ev:?}")
- }
- }
- }
-
- if let Some(debug_tx) = &debug_tx {
- debug_tx
- .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
- ZetaContextRetrievalDebugInfo {
- project: project.clone(),
- timestamp: Instant::now(),
- },
- ))
- .ok();
- }
-
- if selected_ranges.is_empty() {
- log::trace!("context gathering selected no ranges")
- }
-
- selected_ranges.sort_unstable_by(|a, b| {
- a.start_line
- .cmp(&b.start_line)
- .then(b.end_line.cmp(&a.end_line))
- });
-
- let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
-
- for selected_range in selected_ranges {
- if let Some(ResultBuffer { buffer, snapshot }) =
- result_buffers_by_path.get(&selected_range.path)
- {
- let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
- let end_point =
- snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
- let range =
- snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
-
- related_excerpts_by_buffer
- .entry(buffer.clone())
- .or_default()
- .push(range);
- } else {
- log::warn!(
- "selected path that wasn't included in search results: {}",
- selected_range.path.display()
- );
- }
- }
-
- anyhow::Ok(related_excerpts_by_buffer)
- })
- .await
- })
-}
-
-async fn request_tool_call<T: JsonSchema>(
- messages: Vec<LanguageModelRequestMessage>,
- tool_name: &'static str,
- model: &Arc<dyn LanguageModel>,
- cx: &mut AsyncApp,
-) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
-{
- let schema = schemars::schema_for!(T);
-
- let request = LanguageModelRequest {
- messages,
- tools: vec![LanguageModelRequestTool {
- name: tool_name.into(),
- description: schema
- .get("description")
- .and_then(|description| description.as_str())
- .unwrap()
- .to_string(),
- input_schema: serde_json::to_value(schema).unwrap(),
- }],
- ..Default::default()
- };
-
- Ok(model.stream_completion(request, cx).await?)
-}
-
-const MIN_EXCERPT_LEN: usize = 16;
-const MAX_EXCERPT_LEN: usize = 768;
-const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
-
-struct MatchedBuffer {
- snapshot: BufferSnapshot,
- line_ranges: Vec<Range<Line>>,
- full_path: PathBuf,
-}
-
-async fn run_query(
- glob: &str,
- regex: &str,
- results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
- path_style: PathStyle,
- exclude_matcher: PathMatcher,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
-) -> Result<()> {
- let include_matcher = PathMatcher::new(vec![glob], path_style)?;
-
- let query = SearchQuery::regex(
- regex,
- false,
- true,
- false,
- true,
- include_matcher,
- exclude_matcher,
- true,
- None,
- )?;
-
- let results = project.update(cx, |project, cx| project.search(query, cx))?;
- futures::pin_mut!(results);
-
- let mut total_bytes = 0;
-
- while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
- if ranges.is_empty() {
- continue;
- }
-
- let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
- Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
- })?
- else {
- continue;
- };
-
- let results_tx = results_tx.clone();
- cx.background_spawn(async move {
- let mut line_ranges = Vec::with_capacity(ranges.len());
-
- for range in ranges {
- let offset_range = range.to_offset(&snapshot);
- let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
-
- if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
- break;
- }
-
- let excerpt = EditPredictionExcerpt::select_from_buffer(
- query_point,
- &snapshot,
- &EditPredictionExcerptOptions {
- max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
- min_bytes: MIN_EXCERPT_LEN,
- target_before_cursor_over_total_bytes: 0.5,
- },
- None,
- );
-
- if let Some(excerpt) = excerpt {
- total_bytes += excerpt.range.len();
- if !excerpt.line_range.is_empty() {
- line_ranges.push(excerpt.line_range);
- }
- }
- }
-
- results_tx
- .unbounded_send((
- buffer,
- MatchedBuffer {
- snapshot,
- line_ranges,
- full_path,
- },
- ))
- .log_err();
- })
- .detach();
- }
-
- anyhow::Ok(())
-}
@@ -0,0 +1,194 @@
+use std::ops::Range;
+
+use anyhow::Result;
+use collections::HashMap;
+use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions};
+use futures::{
+ StreamExt,
+ channel::mpsc::{self, UnboundedSender},
+};
+use gpui::{AppContext, AsyncApp, Entity};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, ToPoint as _};
+use project::{
+ Project, WorktreeSettings,
+ search::{SearchQuery, SearchResult},
+};
+use util::{
+ ResultExt as _,
+ paths::{PathMatcher, PathStyle},
+};
+use workspace::item::Settings as _;
+
+pub async fn run_retrieval_searches( // Runs every glob→regex search and returns merged, sorted excerpt ranges per buffer.
+    project: Entity<Project>,
+    regex_by_glob: HashMap<String, String>,
+    cx: &mut AsyncApp,
+) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
+    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
+        let global_settings = WorktreeSettings::get_global(cx);
+        let exclude_patterns = global_settings
+            .file_scan_exclusions
+            .sources()
+            .iter()
+            .chain(global_settings.private_files.sources().iter());
+        let path_style = project.path_style(cx);
+        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
+    })??;
+
+    let (results_tx, mut results_rx) = mpsc::unbounded();
+
+    for (glob, regex) in regex_by_glob {
+        let exclude_matcher = exclude_matcher.clone();
+        let results_tx = results_tx.clone();
+        let project = project.clone();
+        cx.spawn(async move |cx| {
+            run_query(
+                &glob,
+                &regex,
+                results_tx.clone(),
+                path_style,
+                exclude_matcher,
+                &project,
+                cx,
+            )
+            .await
+            .log_err();
+        })
+        .detach()
+    }
+    drop(results_tx); // Drop the original sender so `results_rx` terminates once every query task finishes.
+
+    cx.background_spawn(async move {
+        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
+        let mut snapshots = HashMap::default();
+
+        let mut total_bytes = 0;
+        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
+            snapshots.insert(buffer.entity_id(), snapshot);
+            let existing = results.entry(buffer).or_default();
+            existing.reserve(excerpts.len());
+
+            for (range, size) in excerpts {
+                // Blunt trimming of the results until we have a proper algorithmic filtering step
+                if (total_bytes + size) > MAX_RESULTS_LEN {
+                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
+                    break 'outer;
+                }
+                total_bytes += size;
+                existing.push(range);
+            }
+        }
+
+        for (buffer, ranges) in results.iter_mut() {
+            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
+                ranges.sort_unstable_by(|a, b| { // Order by start, then end, so overlaps are adjacent for the merge pass below.
+                    a.start
+                        .cmp(&b.start, snapshot)
+                        .then(a.end.cmp(&b.end, snapshot)) // Fixed: tie-break previously compared `b.end` with itself (a no-op).
+                });
+
+                let mut index = 1;
+                while index < ranges.len() {
+                    if ranges[index - 1]
+                        .end
+                        .cmp(&ranges[index].start, snapshot)
+                        .is_gt()
+                    {
+                        let removed = ranges.remove(index);
+                        if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() { ranges[index - 1].end = removed.end; } // Keep the farthest end: `removed` may be nested inside its predecessor.
+                    } else {
+                        index += 1;
+                    }
+                }
+            }
+        }
+
+        Ok(results)
+    })
+    .await
+}
+
+const MIN_EXCERPT_LEN: usize = 16; // Smallest excerpt emitted around a match, in bytes.
+const MAX_EXCERPT_LEN: usize = 768; // Largest excerpt per match, in bytes.
+const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5; // Byte cap on the combined excerpts across all queries.
+
+async fn run_query( // Searches files matching `glob` for `regex`, streaming per-buffer excerpts to `results_tx`.
+    glob: &str,
+    regex: &str,
+    results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
+    path_style: PathStyle,
+    exclude_matcher: PathMatcher,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Result<()> {
+    let include_matcher = PathMatcher::new(vec![glob], path_style)?;
+
+    let query = SearchQuery::regex( // NOTE(review): five positional bools are easy to transpose — consider annotating each with its parameter name from `SearchQuery::regex`.
+        regex,
+        false,
+        true,
+        false,
+        true,
+        include_matcher,
+        exclude_matcher,
+        true,
+        None,
+    )?;
+
+    let results = project.update(cx, |project, cx| project.search(query, cx))?;
+    futures::pin_mut!(results);
+
+    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
+        if results_tx.is_closed() { // Receiver hung up (e.g. the aggregate byte budget was reached) — stop searching early.
+            break;
+        }
+
+        if ranges.is_empty() {
+            continue;
+        }
+
+        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+        let results_tx = results_tx.clone();
+
+        cx.background_spawn(async move {
+            let mut excerpts = Vec::with_capacity(ranges.len());
+
+            for range in ranges {
+                let offset_range = range.to_offset(&snapshot);
+                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot); // Center excerpt selection on the middle of the match.
+
+                let excerpt = EditPredictionExcerpt::select_from_buffer(
+                    query_point,
+                    &snapshot,
+                    &EditPredictionExcerptOptions {
+                        max_bytes: MAX_EXCERPT_LEN,
+                        min_bytes: MIN_EXCERPT_LEN,
+                        target_before_cursor_over_total_bytes: 0.5,
+                    },
+                    None,
+                );
+
+                if let Some(excerpt) = excerpt
+                    && !excerpt.line_range.is_empty() // Skip degenerate (empty) excerpts.
+                {
+                    excerpts.push((
+                        snapshot.anchor_after(excerpt.range.start)
+                            ..snapshot.anchor_before(excerpt.range.end),
+                        excerpt.range.len(),
+                    ));
+                }
+            }
+
+            let send_result = results_tx.unbounded_send((buffer, snapshot, excerpts));
+
+            if let Err(err) = send_result
+                && !err.is_disconnected() // A dropped receiver is expected after the cap; only log real failures.
+            {
+                log::error!("{err}");
+            }
+        })
+        .detach();
+    }
+
+    anyhow::Ok(())
+}
@@ -0,0 +1,1024 @@
+use std::borrow::Cow;
+use std::fmt::Display;
+use std::sync::Arc;
+use std::{
+ fmt::{Debug, Write},
+ mem,
+ ops::Range,
+ path::Path,
+};
+
+use anyhow::Context as _;
+use anyhow::Result;
+use anyhow::anyhow;
+use collections::HashMap;
+use gpui::AsyncApp;
+use gpui::Entity;
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
+use project::Project;
+
+pub async fn parse_diff<'a>( // Resolves a single-file unified diff into anchored edits on the buffer provided by `get_buffer`.
+    diff: &'a str,
+    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
+) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
+    let mut diff = DiffParser::new(diff);
+    let mut edited_buffer = None;
+    let mut edits = Vec::new();
+
+    while let Some(event) = diff.next()? {
+        match event {
+            DiffEvent::Hunk {
+                path: file_path,
+                hunk,
+            } => {
+                let (buffer, ranges) = match edited_buffer {
+                    None => {
+                        edited_buffer = get_buffer(Path::new(file_path.as_ref()));
+                        edited_buffer
+                            .as_ref()
+                            .context("Model tried to edit a file that wasn't included")?
+                    }
+                    Some(ref current) => current,
+                };
+
+                edits.extend(resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)?);
+            }
+            DiffEvent::FileEnd { renamed_to } => {
+                let (buffer, _) = edited_buffer
+                    .take()
+                    .context("Got a FileEnd event before a Hunk event")?; // Was `expect`: malformed model output (headers but no hunk lines) must not panic.
+
+                if renamed_to.is_some() {
+                    anyhow::bail!("edit predictions cannot rename files");
+                }
+
+                if diff.next()?.is_some() {
+                    anyhow::bail!("Edited more than one file");
+                }
+
+                return Ok((buffer, edits));
+            }
+        }
+    }
+
+    Err(anyhow::anyhow!("No EOF"))
+}
+
+#[derive(Debug)]
+pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>); // Returned by `apply_diff`; presumably keeps the edited buffers alive for the caller — the field is otherwise unused.
+
+#[must_use]
+pub async fn apply_diff<'a>( // Opens every file referenced by `diff`, applies its hunks, and performs renames.
+    diff: &'a str,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Result<OpenedBuffers<'a>> {
+    let mut included_files = HashMap::default();
+
+    for line in diff.lines() { // First pass: open a buffer for every `---` header so hunks can be applied below.
+        let diff_line = DiffLine::parse(line);
+
+        if let DiffLine::OldPath { path } = diff_line {
+            let buffer = project
+                .update(cx, |project, cx| {
+                    let project_path =
+                        project
+                            .find_project_path(path.as_ref(), cx)
+                            .with_context(|| {
+                                format!("Failed to find worktree for new path: {}", path)
+                            })?;
+                    anyhow::Ok(project.open_buffer(project_path, cx))
+                })??
+                .await?;
+
+            included_files.insert(path, buffer);
+        }
+    }
+
+    let ranges = [Anchor::MIN..Anchor::MAX];
+
+    let mut diff = DiffParser::new(diff);
+    let mut current_file = None;
+    let mut edits = vec![];
+
+    while let Some(event) = diff.next()? {
+        match event {
+            DiffEvent::Hunk {
+                path: file_path,
+                hunk,
+            } => {
+                let (buffer, ranges) = match current_file {
+                    None => {
+                        let buffer = included_files
+                            .get_mut(&file_path)
+                            .expect("Opened all files in diff"); // Invariant: the first pass opened every `---` path.
+
+                        current_file = Some((buffer, ranges.as_slice()));
+                        current_file.as_ref().unwrap()
+                    }
+                    Some(ref current) => current,
+                };
+
+                buffer.read_with(cx, |buffer, _| {
+                    edits.extend(resolve_hunk_edits_in_buffer(hunk, buffer, ranges)?);
+                    anyhow::Ok(())
+                })??;
+            }
+            DiffEvent::FileEnd { renamed_to } => {
+                let (buffer, _) = current_file
+                    .take()
+                    .context("Got a FileEnd event before a Hunk event")?; // Was `expect`: a file with headers but no hunks must not panic.
+
+                if let Some(renamed_to) = renamed_to {
+                    project
+                        .update(cx, |project, cx| {
+                            let new_project_path = project
+                                .find_project_path(Path::new(renamed_to.as_ref()), cx)
+                                .with_context(|| {
+                                    format!("Failed to find worktree for new path: {}", renamed_to)
+                                })?;
+
+                            let project_file = project::File::from_dyn(buffer.read(cx).file())
+                                .context("buffer is not backed by a project file")?; // Was `expect("Wrong file type")`.
+
+                            anyhow::Ok(project.rename_entry(
+                                project_file.entry_id.context("buffer has no worktree entry")?, // Was `unwrap`.
+                                new_project_path,
+                                cx,
+                            ))
+                        })??
+                        .await?;
+                }
+
+                let edits = mem::take(&mut edits);
+                buffer.update(cx, |buffer, cx| {
+                    buffer.edit(edits, None, cx);
+                })?;
+            }
+        }
+    }
+
+    Ok(OpenedBuffers(included_files))
+}
+
+struct PatchFile<'a> { // Paths from a file's `---`/`+++` header pair.
+    old_path: Cow<'a, str>, // From the `--- a/...` line.
+    new_path: Cow<'a, str>, // From the `+++ b/...` line; differs from `old_path` on rename.
+}
+
+struct DiffParser<'a> { // Incremental pull-parser over the lines of a unified diff.
+    current_file: Option<PatchFile<'a>>, // File whose hunks are currently being parsed.
+    current_line: Option<(&'a str, DiffLine<'a>)>, // One-line lookahead: raw text plus its parsed form.
+    hunk: Hunk, // Accumulator for the hunk in progress.
+    diff: std::str::Lines<'a>, // Remaining input lines.
+}
+
+#[derive(Debug, PartialEq)]
+enum DiffEvent<'a> { // Events emitted by `DiffParser::next`.
+    Hunk { path: Cow<'a, str>, hunk: Hunk }, // A completed hunk for the file at `path` (the old path).
+    FileEnd { renamed_to: Option<Cow<'a, str>> }, // End of a file's hunks; `Some` when the `+++` path differed.
+}
+
+#[derive(Debug, Default, PartialEq)]
+struct Hunk { // One hunk, normalized for context matching.
+    context: String, // The hunk's pre-image: context and deleted lines, newline-terminated.
+    edits: Vec<Edit>, // Replacements expressed as byte ranges into `context`.
+}
+
+impl Hunk {
+    fn is_empty(&self) -> bool { // True when no hunk lines have been accumulated yet.
+        self.context.is_empty() && self.edits.is_empty()
+    }
+}
+
+#[derive(Debug, PartialEq)]
+struct Edit { // A single replacement within a hunk.
+    range: Range<usize>, // Byte range into `Hunk::context` to be replaced.
+    text: String, // Replacement text (newline-terminated lines).
+}
+
+impl<'a> DiffParser<'a> {
+    fn new(diff: &'a str) -> Self { // Prime the one-line lookahead from the first input line.
+        let mut diff = diff.lines();
+        let current_line = diff.next().map(|line| (line, DiffLine::parse(line)));
+        DiffParser {
+            current_file: None,
+            hunk: Hunk::default(),
+            current_line,
+            diff,
+        }
+    }
+
+    fn next(&mut self) -> Result<Option<DiffEvent<'a>>> { // Returns the next event, or None when input is exhausted.
+        loop {
+            let (hunk_done, file_done) = match self.current_line.as_ref().map(|e| &e.1) {
+                Some(DiffLine::OldPath { .. }) | Some(DiffLine::Garbage(_)) | None => (true, true), // A new `---` header, garbage (e.g. a closing code fence), or EOF ends the current file.
+                Some(DiffLine::HunkHeader(_)) => (true, false), // A new `@@` header ends only the current hunk.
+                _ => (false, false),
+            };
+
+            if hunk_done {
+                if let Some(file) = &self.current_file
+                    && !self.hunk.is_empty()
+                {
+                    return Ok(Some(DiffEvent::Hunk {
+                        path: file.old_path.clone(),
+                        hunk: mem::take(&mut self.hunk),
+                    }));
+                }
+            }
+
+            if file_done {
+                if let Some(PatchFile { old_path, new_path }) = self.current_file.take() {
+                    return Ok(Some(DiffEvent::FileEnd {
+                        renamed_to: if old_path != new_path {
+                            Some(new_path)
+                        } else {
+                            None
+                        },
+                    }));
+                }
+            }
+
+            let Some((line, parsed_line)) = self.current_line.take() else {
+                break;
+            };
+
+            util::maybe!({
+                match parsed_line {
+                    DiffLine::OldPath { path } => {
+                        self.current_file = Some(PatchFile {
+                            old_path: path,
+                            new_path: "".into(), // Filled in by the following `+++` line.
+                        });
+                    }
+                    DiffLine::NewPath { path } => {
+                        if let Some(current_file) = &mut self.current_file {
+                            current_file.new_path = path
+                        }
+                    }
+                    DiffLine::HunkHeader(_) => {} // Location numbers are unused: hunks are located by context matching.
+                    DiffLine::Context(ctx) => {
+                        if self.current_file.is_some() {
+                            writeln!(&mut self.hunk.context, "{ctx}")?;
+                        }
+                    }
+                    DiffLine::Deletion(del) => {
+                        if self.current_file.is_some() {
+                            let range = self.hunk.context.len()
+                                ..self.hunk.context.len() + del.len() + '\n'.len_utf8();
+                            if let Some(last_edit) = self.hunk.edits.last_mut()
+                                && last_edit.range.end == range.start
+                            {
+                                last_edit.range.end = range.end; // Extend the adjacent edit instead of starting a new one.
+                            } else {
+                                self.hunk.edits.push(Edit {
+                                    range,
+                                    text: String::new(),
+                                });
+                            }
+                            writeln!(&mut self.hunk.context, "{del}")?; // Deleted lines are part of the pre-image, so they join `context` too.
+                        }
+                    }
+                    DiffLine::Addition(add) => {
+                        if self.current_file.is_some() {
+                            let range = self.hunk.context.len()..self.hunk.context.len(); // Pure insertion at the current position.
+                            if let Some(last_edit) = self.hunk.edits.last_mut()
+                                && last_edit.range.end == range.start
+                            {
+                                writeln!(&mut last_edit.text, "{add}").unwrap();
+                            } else {
+                                self.hunk.edits.push(Edit {
+                                    range,
+                                    text: format!("{add}\n"),
+                                });
+                            }
+                        }
+                    }
+                    DiffLine::Garbage(_) => {}
+                }
+
+                anyhow::Ok(())
+            })
+            .with_context(|| format!("on line:\n\n```\n{}```", line))?;
+
+            self.current_line = self.diff.next().map(|line| (line, DiffLine::parse(line)));
+        }
+
+        anyhow::Ok(None)
+    }
+}
+
+fn resolve_hunk_edits_in_buffer( // Locates `hunk.context` within `ranges` of the buffer and anchors the hunk's edits there.
+    hunk: Hunk,
+    buffer: &TextBufferSnapshot,
+    ranges: &[Range<Anchor>],
+) -> Result<impl Iterator<Item = (Range<Anchor>, Arc<str>)>, anyhow::Error> {
+    let context_offset = if hunk.context.is_empty() {
+        Ok(0)
+    } else {
+        let mut offset = None;
+        for range in ranges {
+            let range = range.to_offset(buffer);
+            let text = buffer.text_for_range(range.clone()).collect::<String>();
+            for (ix, _) in text.match_indices(&hunk.context) {
+                if offset.is_some() { // The context must match exactly once across all ranges.
+                    anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
+                }
+                offset = Some(range.start + ix);
+            }
+        }
+        offset.ok_or_else(|| {
+            anyhow!(
+                "Failed to match context:\n{}\n\nBuffer:\n{}",
+                hunk.context,
+                buffer.text(),
+            )
+        })
+    }?;
+    let iter = hunk.edits.into_iter().flat_map(move |edit| {
+        let old_text = buffer
+            .text_for_range(context_offset + edit.range.start..context_offset + edit.range.end)
+            .collect::<String>();
+        let edits_within_hunk = language::text_diff(&old_text, &edit.text); // Refine line-level edits into minimal intra-line edits.
+        edits_within_hunk
+            .into_iter()
+            .map(move |(inner_range, inner_text)| {
+                (
+                    buffer.anchor_after(context_offset + edit.range.start + inner_range.start)
+                        ..buffer.anchor_before(context_offset + edit.range.start + inner_range.end),
+                    inner_text,
+                )
+            })
+    });
+    Ok(iter)
+}
+
+#[derive(Debug, PartialEq)]
+pub enum DiffLine<'a> { // Classification of a single line of unified-diff text.
+    OldPath { path: Cow<'a, str> }, // `--- a/...`
+    NewPath { path: Cow<'a, str> }, // `+++ b/...`
+    HunkHeader(Option<HunkLocation>), // `@@ -l,c +l,c @@`, or `@@ ... @@` with no location.
+    Context(&'a str), // Leading-space (or empty) line.
+    Deletion(&'a str), // `-` line.
+    Addition(&'a str), // `+` line.
+    Garbage(&'a str), // Anything unrecognized (prose, fences, `diff --git`, ...).
+}
+
+#[derive(Debug, PartialEq)]
+pub struct HunkLocation { // Parsed `@@` header numbers; start lines are stored 0-based.
+    start_line_old: u32,
+    count_old: u32,
+    start_line_new: u32,
+    count_new: u32,
+}
+
+impl<'a> DiffLine<'a> {
+    pub fn parse(line: &'a str) -> Self { // Never fails: unrecognized lines become `Garbage`.
+        Self::try_parse(line).unwrap_or(Self::Garbage(line))
+    }
+
+    fn try_parse(line: &'a str) -> Option<Self> { // Order matters: `---`/`+++`/`@@` are checked before bare `-`/`+`.
+        if let Some(header) = line.strip_prefix("---").and_then(eat_required_whitespace) {
+            let path = parse_header_path("a/", header);
+            Some(Self::OldPath { path })
+        } else if let Some(header) = line.strip_prefix("+++").and_then(eat_required_whitespace) {
+            Some(Self::NewPath {
+                path: parse_header_path("b/", header),
+            })
+        } else if let Some(header) = line.strip_prefix("@@").and_then(eat_required_whitespace) {
+            if header.starts_with("...") {
+                return Some(Self::HunkHeader(None)); // Location-less `@@ ... @@` style header.
+            }
+
+            let (start_line_old, header) = header.strip_prefix('-')?.split_once(',')?; // NOTE(review): headers with an omitted count (e.g. `@@ -5 +7 @@`) fall through to Garbage — confirm intended.
+            let mut parts = header.split_ascii_whitespace();
+            let count_old = parts.next()?;
+            let (start_line_new, count_new) = parts.next()?.strip_prefix('+')?.split_once(',')?;
+
+            Some(Self::HunkHeader(Some(HunkLocation {
+                start_line_old: start_line_old.parse::<u32>().ok()?.saturating_sub(1), // 1-based on the wire, 0-based in memory.
+                count_old: count_old.parse().ok()?,
+                start_line_new: start_line_new.parse::<u32>().ok()?.saturating_sub(1),
+                count_new: count_new.parse().ok()?,
+            })))
+        } else if let Some(deleted_header) = line.strip_prefix("-") {
+            Some(Self::Deletion(deleted_header))
+        } else if line.is_empty() {
+            Some(Self::Context("")) // A fully empty line counts as empty context.
+        } else if let Some(context) = line.strip_prefix(" ") {
+            Some(Self::Context(context))
+        } else {
+            Some(Self::Addition(line.strip_prefix("+")?))
+        }
+    }
+}
+
+impl<'a> Display for DiffLine<'a> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Renders the line back to unified-diff text (inverse of `parse`).
+        match self {
+            DiffLine::OldPath { path } => write!(f, "--- {path}"),
+            DiffLine::NewPath { path } => write!(f, "+++ {path}"),
+            DiffLine::HunkHeader(Some(hunk_location)) => {
+                write!(
+                    f,
+                    "@@ -{},{} +{},{} @@",
+                    hunk_location.start_line_old + 1, // Back to 1-based for display.
+                    hunk_location.count_old,
+                    hunk_location.start_line_new + 1,
+                    hunk_location.count_new
+                )
+            }
+            DiffLine::HunkHeader(None) => write!(f, "@@ ... @@"),
+            DiffLine::Context(content) => write!(f, " {content}"),
+            DiffLine::Deletion(content) => write!(f, "-{content}"),
+            DiffLine::Addition(content) => write!(f, "+{content}"),
+            DiffLine::Garbage(line) => write!(f, "{line}"),
+        }
+    }
+}
+
+fn parse_header_path<'a>(strip_prefix: &'static str, header: &'a str) -> Cow<'a, str> { // Extracts the path from a `---`/`+++` header, handling quoting, escapes, and the `a/`/`b/` prefix.
+    if !header.contains(['"', '\\']) {
+        let path = header.split_ascii_whitespace().next().unwrap_or(header); // Fast path: unquoted — take the first whitespace-delimited token.
+        return Cow::Borrowed(path.strip_prefix(strip_prefix).unwrap_or(path));
+    }
+
+    let mut path = String::with_capacity(header.len());
+    let mut in_quote = false;
+    let mut chars = header.chars().peekable();
+    let mut strip_prefix = Some(strip_prefix);
+
+    while let Some(char) = chars.next() {
+        if char == '"' {
+            in_quote = !in_quote; // Unescaped quotes toggle quoting and are dropped from the path.
+        } else if char == '\\' {
+            let Some(&next_char) = chars.peek() else {
+                break;
+            };
+            chars.next();
+            path.push(next_char); // Backslash escapes the next character (e.g. `\"`, `\\`).
+        } else if char.is_ascii_whitespace() && !in_quote {
+            break; // Unquoted whitespace ends the path.
+        } else {
+            path.push(char);
+        }
+
+        if let Some(prefix) = strip_prefix
+            && path == prefix
+        {
+            strip_prefix.take(); // Strip `a/`/`b/` at most once, even when it appears inside quotes.
+            path.clear();
+        }
+    }
+
+    Cow::Owned(path)
+}
+
+fn eat_required_whitespace(header: &str) -> Option<&str> {
+    // Succeeds only if `header` begins with at least one ASCII whitespace character.
+    let trimmed = header.trim_ascii_start();
+    if trimmed.len() < header.len() {
+        Some(trimmed)
+    } else {
+        None
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::TestAppContext;
+ use indoc::indoc;
+ use language::Point;
+ use pretty_assertions::assert_eq;
+ use project::{FakeFs, Project};
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+
+ #[test]
+ fn parse_lines_simple() {
+ let input = indoc! {"
+ diff --git a/text.txt b/text.txt
+ index 86c770d..a1fd855 100644
+ --- a/file.txt
+ +++ b/file.txt
+ @@ -1,2 +1,3 @@
+ context
+ -deleted
+ +inserted
+ garbage
+
+ --- b/file.txt
+ +++ a/file.txt
+ "};
+
+ let lines = input.lines().map(DiffLine::parse).collect::<Vec<_>>();
+
+ pretty_assertions::assert_eq!(
+ lines,
+ &[
+ DiffLine::Garbage("diff --git a/text.txt b/text.txt"),
+ DiffLine::Garbage("index 86c770d..a1fd855 100644"),
+ DiffLine::OldPath {
+ path: "file.txt".into()
+ },
+ DiffLine::NewPath {
+ path: "file.txt".into()
+ },
+ DiffLine::HunkHeader(Some(HunkLocation {
+ start_line_old: 0,
+ count_old: 2,
+ start_line_new: 0,
+ count_new: 3
+ })),
+ DiffLine::Context("context"),
+ DiffLine::Deletion("deleted"),
+ DiffLine::Addition("inserted"),
+ DiffLine::Garbage("garbage"),
+ DiffLine::Context(""),
+ DiffLine::OldPath {
+ path: "b/file.txt".into()
+ },
+ DiffLine::NewPath {
+ path: "a/file.txt".into()
+ },
+ ]
+ );
+ }
+
+ #[test]
+ fn file_header_extra_space() {
+ let options = ["--- file", "--- file", "---\tfile"];
+
+ for option in options {
+ pretty_assertions::assert_eq!(
+ DiffLine::parse(option),
+ DiffLine::OldPath {
+ path: "file".into()
+ },
+ "{option}",
+ );
+ }
+ }
+
+ #[test]
+ fn hunk_header_extra_space() {
+ let options = [
+ "@@ -1,2 +1,3 @@",
+ "@@ -1,2 +1,3 @@",
+ "@@\t-1,2\t+1,3\t@@",
+ "@@ -1,2 +1,3 @@",
+ "@@ -1,2 +1,3 @@",
+ "@@ -1,2 +1,3 @@",
+ "@@ -1,2 +1,3 @@ garbage",
+ ];
+
+ for option in options {
+ pretty_assertions::assert_eq!(
+ DiffLine::parse(option),
+ DiffLine::HunkHeader(Some(HunkLocation {
+ start_line_old: 0,
+ count_old: 2,
+ start_line_new: 0,
+ count_new: 3
+ })),
+ "{option}",
+ );
+ }
+ }
+
+ #[test]
+ fn hunk_header_without_location() {
+ pretty_assertions::assert_eq!(DiffLine::parse("@@ ... @@"), DiffLine::HunkHeader(None));
+ }
+
+ #[test]
+ fn test_parse_path() {
+ assert_eq!(parse_header_path("a/", "foo.txt"), "foo.txt");
+ assert_eq!(
+ parse_header_path("a/", "foo/bar/baz.txt"),
+ "foo/bar/baz.txt"
+ );
+ assert_eq!(parse_header_path("a/", "a/foo.txt"), "foo.txt");
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt"),
+ "foo/bar/baz.txt"
+ );
+
+ // Extra
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt 2025"),
+ "foo/bar/baz.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt\t2025"),
+ "foo/bar/baz.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/baz.txt \""),
+ "foo/bar/baz.txt"
+ );
+
+ // Quoted
+ assert_eq!(
+ parse_header_path("a/", "a/foo/bar/\"baz quox.txt\""),
+ "foo/bar/baz quox.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "\"a/foo/bar/baz quox.txt\""),
+ "foo/bar/baz quox.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "\"foo/bar/baz quox.txt\""),
+ "foo/bar/baz quox.txt"
+ );
+ assert_eq!(parse_header_path("a/", "\"whatever ๐คท\""), "whatever ๐คท");
+ assert_eq!(
+ parse_header_path("a/", "\"foo/bar/baz quox.txt\" 2025"),
+ "foo/bar/baz quox.txt"
+ );
+ // unescaped quotes are dropped
+ assert_eq!(parse_header_path("a/", "foo/\"bar\""), "foo/bar");
+
+ // Escaped
+ assert_eq!(
+ parse_header_path("a/", "\"foo/\\\"bar\\\"/baz.txt\""),
+ "foo/\"bar\"/baz.txt"
+ );
+ assert_eq!(
+ parse_header_path("a/", "\"C:\\\\Projects\\\\My App\\\\old file.txt\""),
+ "C:\\Projects\\My App\\old file.txt"
+ );
+ }
+
+ #[test]
+ fn test_parse_diff_with_leading_and_trailing_garbage() {
+ let diff = indoc! {"
+ I need to make some changes.
+
+ I'll change the following things:
+ - one
+ - two
+ - three
+
+ ```
+ --- a/file.txt
+ +++ b/file.txt
+ one
+ +AND
+ two
+ ```
+
+ Summary of what I did:
+ - one
+ - two
+ - three
+
+ That's about it.
+ "};
+
+ let mut events = Vec::new();
+ let mut parser = DiffParser::new(diff);
+ while let Some(event) = parser.next().unwrap() {
+ events.push(event);
+ }
+
+ assert_eq!(
+ events,
+ &[
+ DiffEvent::Hunk {
+ path: "file.txt".into(),
+ hunk: Hunk {
+ context: "one\ntwo\n".into(),
+ edits: vec![Edit {
+ range: 4..4,
+ text: "AND\n".into()
+ }],
+ }
+ },
+ DiffEvent::FileEnd { renamed_to: None }
+ ],
+ )
+ }
+
+ #[gpui::test]
+ async fn test_apply_diff_successful(cx: &mut TestAppContext) {
+ let fs = init_test(cx);
+
+ let buffer_1_text = indoc! {r#"
+ one
+ two
+ three
+ four
+ five
+ "# };
+
+ let buffer_1_text_final = indoc! {r#"
+ 3
+ 4
+ 5
+ "# };
+
+ let buffer_2_text = indoc! {r#"
+ six
+ seven
+ eight
+ nine
+ ten
+ "# };
+
+ let buffer_2_text_final = indoc! {r#"
+ 5
+ six
+ seven
+ 7.5
+ eight
+ nine
+ ten
+ 11
+ "# };
+
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "file1": buffer_1_text,
+ "file2": buffer_2_text,
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+ let diff = indoc! {r#"
+ --- a/root/file1
+ +++ b/root/file1
+ one
+ two
+ -three
+ +3
+ four
+ five
+ --- a/root/file1
+ +++ b/root/file1
+ 3
+ -four
+ -five
+ +4
+ +5
+ --- a/root/file1
+ +++ b/root/file1
+ -one
+ -two
+ 3
+ 4
+ --- a/root/file2
+ +++ b/root/file2
+ +5
+ six
+ --- a/root/file2
+ +++ b/root/file2
+ seven
+ +7.5
+ eight
+ --- a/root/file2
+ +++ b/root/file2
+ ten
+ +11
+ "#};
+
+ let _buffers = apply_diff(diff, &project, &mut cx.to_async())
+ .await
+ .unwrap();
+ let buffer_1 = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })
+ .await
+ .unwrap();
+
+ buffer_1.read_with(cx, |buffer, _cx| {
+ assert_eq!(buffer.text(), buffer_1_text_final);
+ });
+ let buffer_2 = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })
+ .await
+ .unwrap();
+
+ buffer_2.read_with(cx, |buffer, _cx| {
+ assert_eq!(buffer.text(), buffer_2_text_final);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
+ let fs = init_test(cx);
+
+ let buffer_1_text = indoc! {r#"
+ one
+ two
+ three
+ four
+ five
+ one
+ two
+ three
+ four
+ five
+ "# };
+
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "file1": buffer_1_text,
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/root/file1"), cx)
+ })
+ .await
+ .unwrap();
+ let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
+
+ let diff = indoc! {r#"
+ --- a/root/file1
+ +++ b/root/file1
+ one
+ two
+ -three
+ +3
+ four
+ five
+ "#};
+
+ let final_text = indoc! {r#"
+ one
+ two
+ three
+ four
+ five
+ one
+ two
+ 3
+ four
+ five
+ "#};
+
+ apply_diff(diff, &project, &mut cx.to_async())
+ .await
+ .expect_err("Non-unique edits should fail");
+
+ let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
+ ..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
+
+ let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
+ .await
+ .unwrap();
+
+ assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(edits, None, cx);
+ assert_eq!(buffer.text(), final_text);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
+ let fs = init_test(cx);
+
+ let buffer_1_text = indoc! {r#"
+ one two three four
+ five six seven eight
+ nine ten eleven twelve
+ "# };
+
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "file1": buffer_1_text,
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/root/file1"), cx)
+ })
+ .await
+ .unwrap();
+ let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
+
+ let diff = indoc! {r#"
+ --- a/root/file1
+ +++ b/root/file1
+ one two three four
+ -five six seven eight
+ +five SIX seven eight!
+ nine ten eleven twelve
+ "#};
+
+ let (buffer, edits) = parse_diff(diff, |_path| {
+ Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
+ })
+ .await
+ .unwrap();
+
+ let edits = edits
+ .into_iter()
+ .map(|(range, text)| (range.to_point(&buffer), text))
+ .collect::<Vec<_>>();
+ assert_eq!(
+ edits,
+ &[
+ (Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
+ (Point::new(1, 20)..Point::new(1, 20), "!".into())
+ ]
+ );
+ }
+
+ #[gpui::test]
+ async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
+ let fs = init_test(cx);
+
+ let start = indoc! {r#"
+ one
+ two
+ three
+ four
+ five
+
+ four
+ five
+ "# };
+
+ let end = indoc! {r#"
+ one
+ two
+ 3
+ four
+ 5
+
+ four
+ five
+ "# };
+
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "file1": start,
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
+
+ let diff = indoc! {r#"
+ --- a/root/file1
+ +++ b/root/file1
+ one
+ two
+ -three
+ +3
+ four
+ -five
+ +5
+ "#};
+
+ let _buffers = apply_diff(diff, &project, &mut cx.to_async())
+ .await
+ .unwrap();
+
+ let buffer_1 = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })
+ .await
+ .unwrap();
+
+ buffer_1.read_with(cx, |buffer, _cx| {
+ assert_eq!(buffer.text(), end);
+ });
+ }
+
+ fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ Project::init_settings(cx);
+ language::init(cx);
+ });
+
+ FakeFs::new(cx.background_executor.clone())
+ }
+}
@@ -6,7 +6,8 @@ use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
ZED_VERSION_HEADER_NAME,
};
-use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, build_prompt};
+use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
+use cloud_zeta2_prompt::retrieval_prompt::SearchToolInput;
use collections::HashMap;
use edit_prediction_context::{
DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
@@ -24,11 +25,13 @@ use gpui::{
use language::{Anchor, Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use open_ai::FunctionDefinition;
use project::Project;
use release_channel::AppVersion;
use serde::de::DeserializeOwned;
use std::collections::{VecDeque, hash_map};
-use std::fmt::Write;
+use uuid::Uuid;
+
use std::ops::Range;
use std::path::Path;
use std::str::FromStr as _;
@@ -42,12 +45,12 @@ use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_noti
pub mod merge_excerpts;
mod prediction;
mod provider;
-pub mod related_excerpts;
+pub mod retrieval_search;
+pub mod udiff;
use crate::merge_excerpts::merge_excerpts;
use crate::prediction::EditPrediction;
-use crate::related_excerpts::find_related_excerpts;
-pub use crate::related_excerpts::{LlmContextOptions, SearchToolQuery};
+pub use crate::prediction::EditPredictionId;
pub use provider::ZetaEditPredictionProvider;
/// Maximum number of events to track.
@@ -59,9 +62,10 @@ pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPrediction
target_before_cursor_over_total_bytes: 0.5,
};
-pub const DEFAULT_CONTEXT_OPTIONS: ContextMode = ContextMode::Llm(DEFAULT_LLM_CONTEXT_OPTIONS);
+pub const DEFAULT_CONTEXT_OPTIONS: ContextMode =
+ ContextMode::Agentic(DEFAULT_AGENTIC_CONTEXT_OPTIONS);
-pub const DEFAULT_LLM_CONTEXT_OPTIONS: LlmContextOptions = LlmContextOptions {
+pub const DEFAULT_AGENTIC_CONTEXT_OPTIONS: AgenticContextOptions = AgenticContextOptions {
excerpt: DEFAULT_EXCERPT_OPTIONS,
};
@@ -122,14 +126,19 @@ pub struct ZetaOptions {
#[derive(Debug, Clone, PartialEq)]
pub enum ContextMode {
- Llm(LlmContextOptions),
+ Agentic(AgenticContextOptions),
Syntax(EditPredictionContextOptions),
}
+#[derive(Debug, Clone, PartialEq)]
+pub struct AgenticContextOptions {
+ pub excerpt: EditPredictionExcerptOptions,
+}
+
impl ContextMode {
pub fn excerpt(&self) -> &EditPredictionExcerptOptions {
match self {
- ContextMode::Llm(options) => &options.excerpt,
+ ContextMode::Agentic(options) => &options.excerpt,
ContextMode::Syntax(options) => &options.excerpt,
}
}
@@ -140,9 +149,8 @@ pub enum ZetaDebugInfo {
ContextRetrievalStarted(ZetaContextRetrievalStartedDebugInfo),
SearchQueriesGenerated(ZetaSearchQueryDebugInfo),
SearchQueriesExecuted(ZetaContextRetrievalDebugInfo),
- SearchResultsFiltered(ZetaContextRetrievalDebugInfo),
ContextRetrievalFinished(ZetaContextRetrievalDebugInfo),
- EditPredicted(ZetaEditPredictionDebugInfo),
+ EditPredictionRequested(ZetaEditPredictionDebugInfo),
}
#[derive(Debug)]
@@ -165,14 +173,14 @@ pub struct ZetaEditPredictionDebugInfo {
pub buffer: WeakEntity<Buffer>,
pub position: language::Anchor,
pub local_prompt: Result<String, String>,
- pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
+ pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, TimeDelta)>,
}
#[derive(Debug)]
pub struct ZetaSearchQueryDebugInfo {
pub project: Entity<Project>,
pub timestamp: Instant,
- pub queries: Vec<SearchToolQuery>,
+ pub regex_by_glob: HashMap<String, String>,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
@@ -224,7 +232,7 @@ impl CurrentEditPrediction {
{
let (old_range, old_text) = &old_edits[0];
let (new_range, new_text) = &new_edits[0];
- new_range == old_range && new_text.starts_with(old_text)
+ new_range == old_range && new_text.starts_with(old_text.as_ref())
} else {
true
}
@@ -539,7 +547,7 @@ impl Zeta {
prediction,
} = project_state.current_prediction.as_ref()?;
- if prediction.targets_buffer(buffer.read(cx), cx) {
+ if prediction.targets_buffer(buffer.read(cx)) {
Some(BufferEditPrediction::Local { prediction })
} else if *requested_by_buffer_id == buffer.entity_id() {
Some(BufferEditPrediction::Jump { prediction })
@@ -639,7 +647,7 @@ impl Zeta {
pub fn request_prediction(
&mut self,
project: &Entity<Project>,
- buffer: &Entity<Buffer>,
+ active_buffer: &Entity<Buffer>,
position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<Option<EditPrediction>>> {
@@ -651,8 +659,8 @@ impl Zeta {
.read_with(cx, |index, _cx| index.state().clone())
});
let options = self.options.clone();
- let snapshot = buffer.read(cx).snapshot();
- let Some(excerpt_path) = snapshot
+ let active_snapshot = active_buffer.read(cx).snapshot();
+ let Some(excerpt_path) = active_snapshot
.file()
.map(|path| -> Arc<Path> { path.full_path(cx).into() })
else {
@@ -678,12 +686,13 @@ impl Zeta {
})
.unwrap_or_default();
- let diagnostics = snapshot.diagnostic_sets().clone();
+ let diagnostics = active_snapshot.diagnostic_sets().clone();
- let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
- let mut path = f.worktree.read(cx).absolutize(&f.path);
- if path.pop() { Some(path) } else { None }
- });
+ let parent_abs_path =
+ project::File::from_dyn(active_buffer.read(cx).file()).and_then(|f| {
+ let mut path = f.worktree.read(cx).absolutize(&f.path);
+ if path.pop() { Some(path) } else { None }
+ });
// TODO data collection
let can_collect_data = cx.is_staff();
@@ -692,9 +701,10 @@ impl Zeta {
.and_then(|project_state| project_state.context.as_ref())
.unwrap_or(&HashMap::default())
.iter()
- .filter_map(|(buffer, ranges)| {
- let buffer = buffer.read(cx);
+ .filter_map(|(buffer_entity, ranges)| {
+ let buffer = buffer_entity.read(cx);
Some((
+ buffer_entity.clone(),
buffer.snapshot(),
buffer.file()?.full_path(cx).into(),
ranges.clone(),
@@ -703,8 +713,7 @@ impl Zeta {
.collect::<Vec<_>>();
let request_task = cx.background_spawn({
- let snapshot = snapshot.clone();
- let buffer = buffer.clone();
+ let active_buffer = active_buffer.clone();
async move {
let index_state = if let Some(index_state) = index_state {
Some(index_state.lock_owned().await)
@@ -712,8 +721,8 @@ impl Zeta {
None
};
- let cursor_offset = position.to_offset(&snapshot);
- let cursor_point = cursor_offset.to_point(&snapshot);
+ let cursor_offset = position.to_offset(&active_snapshot);
+ let cursor_point = cursor_offset.to_point(&active_snapshot);
let before_retrieval = chrono::Utc::now();
@@ -721,29 +730,30 @@ impl Zeta {
Self::gather_nearby_diagnostics(
cursor_offset,
&diagnostics,
- &snapshot,
+ &active_snapshot,
options.max_diagnostic_bytes,
);
- let request = match options.context {
- ContextMode::Llm(context_options) => {
+ let cloud_request = match options.context {
+ ContextMode::Agentic(context_options) => {
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
cursor_point,
- &snapshot,
+ &active_snapshot,
&context_options.excerpt,
index_state.as_deref(),
) else {
return Ok((None, None));
};
- let excerpt_anchor_range = snapshot.anchor_after(excerpt.range.start)
- ..snapshot.anchor_before(excerpt.range.end);
+ let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
+ ..active_snapshot.anchor_before(excerpt.range.end);
- if let Some(buffer_ix) = included_files
- .iter()
- .position(|(buffer, _, _)| buffer.remote_id() == snapshot.remote_id())
+ if let Some(buffer_ix) =
+ included_files.iter().position(|(_, snapshot, _, _)| {
+ snapshot.remote_id() == active_snapshot.remote_id()
+ })
{
- let (buffer, _, ranges) = &mut included_files[buffer_ix];
+ let (_, buffer, _, ranges) = &mut included_files[buffer_ix];
let range_ix = ranges
.binary_search_by(|probe| {
probe
@@ -758,15 +768,16 @@ impl Zeta {
included_files.swap(buffer_ix, last_ix);
} else {
included_files.push((
- snapshot,
+ active_buffer.clone(),
+ active_snapshot,
excerpt_path.clone(),
vec![excerpt_anchor_range],
));
}
let included_files = included_files
- .into_iter()
- .map(|(buffer, path, ranges)| {
+ .iter()
+ .map(|(_, buffer, path, ranges)| {
let excerpts = merge_excerpts(
&buffer,
ranges.iter().map(|range| {
@@ -775,7 +786,7 @@ impl Zeta {
}),
);
predict_edits_v3::IncludedFile {
- path,
+ path: path.clone(),
max_row: Line(buffer.max_point().row),
excerpts,
}
@@ -809,7 +820,7 @@ impl Zeta {
ContextMode::Syntax(context_options) => {
let Some(context) = EditPredictionContext::gather_context(
cursor_point,
- &snapshot,
+ &active_snapshot,
parent_abs_path.as_deref(),
&context_options,
index_state.as_deref(),
@@ -834,24 +845,27 @@ impl Zeta {
}
};
+ let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
+
let retrieval_time = chrono::Utc::now() - before_retrieval;
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
let (response_tx, response_rx) = oneshot::channel();
- let local_prompt = build_prompt(&request)
- .map(|(prompt, _)| prompt)
- .map_err(|err| err.to_string());
-
debug_tx
- .unbounded_send(ZetaDebugInfo::EditPredicted(ZetaEditPredictionDebugInfo {
- request: request.clone(),
- retrieval_time,
- buffer: buffer.downgrade(),
- local_prompt,
- position,
- response_rx,
- }))
+ .unbounded_send(ZetaDebugInfo::EditPredictionRequested(
+ ZetaEditPredictionDebugInfo {
+ request: cloud_request.clone(),
+ retrieval_time,
+ buffer: active_buffer.downgrade(),
+ local_prompt: match prompt_result.as_ref() {
+ Ok((prompt, _)) => Ok(prompt.clone()),
+ Err(err) => Err(err.to_string()),
+ },
+ position,
+ response_rx,
+ },
+ ))
.ok();
Some(response_tx)
} else {
@@ -861,61 +875,144 @@ impl Zeta {
if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
- .send(Err("Request skipped".to_string()))
+ .send((Err("Request skipped".to_string()), TimeDelta::zero()))
.ok();
}
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
+ let (prompt, _) = prompt_result?;
+ let request = open_ai::Request {
+ model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string()),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: Default::default(),
+ temperature: 0.7,
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
+ log::trace!("Sending edit prediction request");
+
+ let before_request = chrono::Utc::now();
let response =
- Self::send_prediction_request(client, llm_token, app_version, request).await;
+ Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+ let request_time = chrono::Utc::now() - before_request;
+
+ log::trace!("Got edit prediction response");
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
- .send(
+ .send((
response
.as_ref()
.map_err(|err| err.to_string())
.map(|response| response.0.clone()),
- )
+ request_time,
+ ))
.ok();
}
- response.map(|(res, usage)| (Some(res), usage))
+ let (mut res, usage) = response?;
+
+ let request_id = EditPredictionId(Uuid::from_str(&res.id)?);
+
+ let Some(choice) = res.choices.pop() else {
+ return Ok((None, usage));
+ };
+
+ let output_text = match choice.message {
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(content)),
+ ..
+ } => content,
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Multipart(mut content)),
+ ..
+ } => {
+ if content.is_empty() {
+ log::error!("No output from Baseten completion response");
+ return Ok((None, usage));
+ }
+
+ match content.remove(0) {
+ open_ai::MessagePart::Text { text } => text,
+ open_ai::MessagePart::Image { .. } => {
+ log::error!("Expected text, got an image");
+ return Ok((None, usage));
+ }
+ }
+ }
+ _ => {
+ log::error!("Invalid response message: {:?}", choice.message);
+ return Ok((None, usage));
+ }
+ };
+
+ let (edited_buffer_snapshot, edits) =
+ crate::udiff::parse_diff(&output_text, |path| {
+ included_files
+ .iter()
+ .find_map(|(_, buffer, probe_path, ranges)| {
+ if probe_path.as_ref() == path {
+ Some((buffer, ranges.as_slice()))
+ } else {
+ None
+ }
+ })
+ })
+ .await?;
+
+ let edited_buffer = included_files
+ .iter()
+ .find_map(|(buffer, snapshot, _, _)| {
+ if snapshot.remote_id() == edited_buffer_snapshot.remote_id() {
+ Some(buffer.clone())
+ } else {
+ None
+ }
+ })
+ .context("Failed to find buffer in included_buffers, even though we just found the snapshot")?;
+
+ anyhow::Ok((Some((request_id, edited_buffer, edited_buffer_snapshot.clone(), edits)), usage))
}
});
- let buffer = buffer.clone();
-
cx.spawn({
- let project = project.clone();
async move |this, cx| {
- let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
+ let Some((id, edited_buffer, edited_buffer_snapshot, edits)) =
+ Self::handle_api_response(&this, request_task.await, cx)?
else {
return Ok(None);
};
// TODO telemetry: duration, etc
- Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
+ Ok(
+ EditPrediction::new(id, &edited_buffer, &edited_buffer_snapshot, edits, cx)
+ .await,
+ )
}
})
}
- async fn send_prediction_request(
+ async fn send_raw_llm_request(
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: SemanticVersion,
- request: predict_edits_v3::PredictEditsRequest,
- ) -> Result<(
- predict_edits_v3::PredictEditsResponse,
- Option<EditPredictionUsage>,
- )> {
+ request: open_ai::Request,
+ ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
http_client::Url::parse(&predict_edits_url)?
} else {
client
.http_client()
- .build_zed_llm_url("/predict_edits/v3", &[])?
+ .build_zed_llm_url("/predict_edits/raw", &[])?
};
Self::send_api_request(
@@ -1052,7 +1149,7 @@ impl Zeta {
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) {
- if !matches!(&self.options().context, ContextMode::Llm { .. }) {
+ if !matches!(&self.options().context, ContextMode::Agentic { .. }) {
return;
}
@@ -1100,36 +1197,149 @@ impl Zeta {
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
+ let Some(zeta_project) = self.projects.get(&project.entity_id()) else {
+ return Task::ready(anyhow::Ok(()));
+ };
+
+ let ContextMode::Agentic(options) = &self.options().context else {
+ return Task::ready(anyhow::Ok(()));
+ };
+
+ let snapshot = buffer.read(cx).snapshot();
+ let cursor_point = cursor_position.to_point(&snapshot);
+ let Some(cursor_excerpt) = EditPredictionExcerpt::select_from_buffer(
+ cursor_point,
+ &snapshot,
+ &options.excerpt,
+ None,
+ ) else {
+ return Task::ready(Ok(()));
+ };
+
+ let current_file_path: Arc<Path> = snapshot
+ .file()
+ .map(|f| f.full_path(cx).into())
+ .unwrap_or_else(|| Path::new("untitled").into());
+
+ let prompt = match cloud_zeta2_prompt::retrieval_prompt::build_prompt(
+ predict_edits_v3::PlanContextRetrievalRequest {
+ excerpt: cursor_excerpt.text(&snapshot).body,
+ excerpt_path: current_file_path,
+ excerpt_line_range: cursor_excerpt.line_range,
+ cursor_file_max_row: Line(snapshot.max_point().row),
+ events: zeta_project
+ .events
+ .iter()
+ .filter_map(|ev| ev.to_request_event(cx))
+ .collect(),
+ },
+ ) {
+ Ok(prompt) => prompt,
+ Err(err) => {
+ return Task::ready(Err(err));
+ }
+ };
+
+ let app_version = AppVersion::global(cx);
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let debug_tx = self.debug_tx.clone();
+
+ let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;
+
+ let request = open_ai::Request {
+ model: std::env::var("ZED_ZETA2_MODEL").unwrap_or("2327jz9q".to_string()),
+ messages: vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt),
+ }],
+ stream: false,
+ max_completion_tokens: None,
+ stop: Default::default(),
+ temperature: 0.7,
+ tool_choice: None,
+ parallel_tool_calls: None,
+ tools: vec![open_ai::ToolDefinition::Function {
+ function: FunctionDefinition {
+ name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME.to_string(),
+ description: Some(tool_description.clone()),
+ parameters: Some(tool_schema.clone()),
+ },
+ }],
+ prompt_cache_key: None,
+ reasoning_effort: None,
+ };
+
cx.spawn(async move |this, cx| {
- let related_excerpts_result = this
- .update(cx, |this, cx| {
- let Some(zeta_project) = this.projects.get(&project.entity_id()) else {
- return Task::ready(anyhow::Ok(HashMap::default()));
- };
+ log::trace!("Sending search planning request");
+ let response =
+ Self::send_raw_llm_request(client, llm_token, app_version, request).await;
+ let mut response = Self::handle_api_response(&this, response, cx)?;
+
+ log::trace!("Got search planning response");
+
+ let choice = response
+ .choices
+ .pop()
+ .context("No choices in retrieval response")?;
+ let open_ai::RequestMessage::Assistant {
+ content: _,
+ tool_calls,
+ } = choice.message
+ else {
+ anyhow::bail!("Retrieval response didn't include an assistant message");
+ };
- let ContextMode::Llm(options) = &this.options().context else {
- return Task::ready(anyhow::Ok(HashMap::default()));
- };
+ let mut regex_by_glob: HashMap<String, String> = HashMap::default();
+ for tool_call in tool_calls {
+ let open_ai::ToolCallContent::Function { function } = tool_call.content;
+ if function.name != cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME {
+ log::warn!(
+ "Context retrieval response tried to call an unknown tool: {}",
+ function.name
+ );
- let mut edit_history_unified_diff = String::new();
+ continue;
+ }
- for event in zeta_project.events.iter() {
- if let Some(event) = event.to_request_event(cx) {
- writeln!(&mut edit_history_unified_diff, "{event}").ok();
- }
+ let input: SearchToolInput = serde_json::from_str(&function.arguments)?;
+ for query in input.queries {
+ let regex = regex_by_glob.entry(query.glob).or_default();
+ if !regex.is_empty() {
+ regex.push('|');
}
+ regex.push_str(&query.regex);
+ }
+ }
- find_related_excerpts(
- buffer.clone(),
- cursor_position,
- &project,
- edit_history_unified_diff,
- options,
- this.debug_tx.clone(),
- cx,
- )
- })?
- .await;
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
+ ZetaSearchQueryDebugInfo {
+ project: project.clone(),
+ timestamp: Instant::now(),
+ regex_by_glob: regex_by_glob.clone(),
+ },
+ ))
+ .ok();
+ }
+
+ log::trace!("Running retrieval search: {regex_by_glob:#?}");
+
+ let related_excerpts_result =
+ retrieval_search::run_retrieval_searches(project.clone(), regex_by_glob, cx).await;
+
+ log::trace!("Search queries executed");
+
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
+ ZetaContextRetrievalDebugInfo {
+ project: project.clone(),
+ timestamp: Instant::now(),
+ },
+ ))
+ .ok();
+ }
this.update(cx, |this, _cx| {
let Some(zeta_project) = this.projects.get_mut(&project.entity_id()) else {
@@ -1249,7 +1459,7 @@ impl Zeta {
&snapshot,
parent_abs_path.as_deref(),
match &options.context {
- ContextMode::Llm(_) => {
+ ContextMode::Agentic(_) => {
// TODO
panic!("Llm mode not supported in zeta cli yet");
}
@@ -1426,15 +1636,11 @@ fn add_signature(
#[cfg(test)]
mod tests {
- use std::{
- path::{Path, PathBuf},
- sync::Arc,
- };
+ use std::{path::Path, sync::Arc};
use client::UserStore;
use clock::FakeSystemClock;
- use cloud_llm_client::predict_edits_v3::{self, Point};
- use edit_prediction_context::Line;
+ use cloud_zeta2_prompt::retrieval_prompt::{SearchToolInput, SearchToolQuery};
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -1445,7 +1651,8 @@ mod tests {
prelude::*,
};
use indoc::indoc;
- use language::{LanguageServerId, OffsetRangeExt as _};
+ use language::OffsetRangeExt as _;
+ use open_ai::Usage;
use pretty_assertions::{assert_eq, assert_matches};
use project::{FakeFs, Project};
use serde_json::json;
@@ -1462,8 +1669,8 @@ mod tests {
fs.insert_tree(
"/root",
json!({
- "1.txt": "Hello!\nHow\nBye",
- "2.txt": "Hola!\nComo\nAdios"
+ "1.txt": "Hello!\nHow\nBye\n",
+ "2.txt": "Hola!\nComo\nAdios\n"
}),
)
.await;
@@ -1489,16 +1696,17 @@ mod tests {
zeta.refresh_prediction(&project, &buffer1, position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
+
respond_tx
- .send(predict_edits_v3::PredictEditsResponse {
- request_id: Uuid::new_v4(),
- edits: vec![predict_edits_v3::Edit {
- path: Path::new(path!("root/1.txt")).into(),
- range: Line(0)..Line(snapshot1.max_point().row + 1),
- content: "Hello!\nHow are you?\nBye".into(),
- }],
- debug_info: None,
- })
+ .send(model_response(indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
.unwrap();
prediction_task.await.unwrap();
@@ -1509,21 +1717,67 @@ mod tests {
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
+ // Context refresh
+ let refresh_task = zeta.update(cx, |zeta, cx| {
+ zeta.refresh_context(project.clone(), buffer1.clone(), position, cx)
+ });
+ let (_request, respond_tx) = req_rx.next().await.unwrap();
+ respond_tx
+ .send(open_ai::Response {
+ id: Uuid::new_v4().to_string(),
+ object: "response".into(),
+ created: 0,
+ model: "model".into(),
+ choices: vec![open_ai::Choice {
+ index: 0,
+ message: open_ai::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![open_ai::ToolCall {
+ id: "search".into(),
+ content: open_ai::ToolCallContent::Function {
+ function: open_ai::FunctionContent {
+ name: cloud_zeta2_prompt::retrieval_prompt::TOOL_NAME
+ .to_string(),
+ arguments: serde_json::to_string(&SearchToolInput {
+ queries: Box::new([SearchToolQuery {
+ glob: "root/2.txt".to_string(),
+ regex: ".".to_string(),
+ }]),
+ })
+ .unwrap(),
+ },
+ },
+ }],
+ },
+ finish_reason: None,
+ }],
+ usage: Usage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ })
+ .unwrap();
+ refresh_task.await.unwrap();
+
+ zeta.update(cx, |zeta, _cx| {
+ zeta.discard_current_prediction(&project);
+ });
+
// Prediction for another file
let prediction_task = zeta.update(cx, |zeta, cx| {
zeta.refresh_prediction(&project, &buffer1, position, cx)
});
let (_request, respond_tx) = req_rx.next().await.unwrap();
respond_tx
- .send(predict_edits_v3::PredictEditsResponse {
- request_id: Uuid::new_v4(),
- edits: vec![predict_edits_v3::Edit {
- path: Path::new(path!("root/2.txt")).into(),
- range: Line(0)..Line(snapshot1.max_point().row + 1),
- content: "Hola!\nComo estas?\nAdios".into(),
- }],
- debug_info: None,
- })
+ .send(model_response(indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#}))
.unwrap();
prediction_task.await.unwrap();
zeta.read_with(cx, |zeta, cx| {
@@ -1532,7 +1786,7 @@ mod tests {
.unwrap();
assert_matches!(
prediction,
- BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
+ BufferEditPrediction::Jump { prediction } if prediction.snapshot.file().unwrap().full_path(cx) == Path::new(path!("root/2.txt"))
);
});
@@ -1559,7 +1813,7 @@ mod tests {
fs.insert_tree(
"/root",
json!({
- "foo.md": "Hello!\nHow\nBye"
+ "foo.md": "Hello!\nHow\nBye\n"
}),
)
.await;
@@ -1579,29 +1833,31 @@ mod tests {
zeta.request_prediction(&project, &buffer, position, cx)
});
- let (request, respond_tx) = req_rx.next().await.unwrap();
- assert_eq!(
- request.excerpt_path.as_ref(),
- Path::new(path!("root/foo.md"))
- );
- assert_eq!(
- request.cursor_point,
- Point {
- line: Line(1),
- column: 3
- }
- );
+ let (_, respond_tx) = req_rx.next().await.unwrap();
+
+ // TODO Put back when we have a structured request again
+ // assert_eq!(
+ // request.excerpt_path.as_ref(),
+ // Path::new(path!("root/foo.md"))
+ // );
+ // assert_eq!(
+ // request.cursor_point,
+ // Point {
+ // line: Line(1),
+ // column: 3
+ // }
+ // );
respond_tx
- .send(predict_edits_v3::PredictEditsResponse {
- request_id: Uuid::new_v4(),
- edits: vec![predict_edits_v3::Edit {
- path: Path::new(path!("root/foo.md")).into(),
- range: Line(0)..Line(snapshot.max_point().row + 1),
- content: "Hello!\nHow are you?\nBye".into(),
- }],
- debug_info: None,
- })
+ .send(model_response(indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "}))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap();
@@ -1611,7 +1867,7 @@ mod tests {
prediction.edits[0].0.to_point(&snapshot).start,
language::Point::new(1, 3)
);
- assert_eq!(prediction.edits[0].1, " are you?");
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
#[gpui::test]
@@ -1621,7 +1877,7 @@ mod tests {
fs.insert_tree(
"/root",
json!({
- "foo.md": "Hello!\n\nBye"
+ "foo.md": "Hello!\n\nBye\n"
}),
)
.await;
@@ -1652,34 +1908,30 @@ mod tests {
let (request, respond_tx) = req_rx.next().await.unwrap();
- assert_eq!(request.events.len(), 1);
- assert_eq!(
- request.events[0],
- predict_edits_v3::Event::BufferChange {
- path: Some(PathBuf::from(path!("root/foo.md"))),
- old_path: None,
- diff: indoc! {"
- @@ -1,3 +1,3 @@
- Hello!
- -
- +How
- Bye
- "}
- .to_string(),
- predicted: false
- }
+ let prompt = prompt_from_request(&request);
+ assert!(
+ prompt.contains(indoc! {"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ -1,3 +1,3 @@
+ Hello!
+ -
+ +How
+ Bye
+ "}),
+ "{prompt}"
);
respond_tx
- .send(predict_edits_v3::PredictEditsResponse {
- request_id: Uuid::new_v4(),
- edits: vec![predict_edits_v3::Edit {
- path: Path::new(path!("root/foo.md")).into(),
- range: Line(0)..Line(snapshot.max_point().row + 1),
- content: "Hello!\nHow are you?\nBye".into(),
- }],
- debug_info: None,
- })
+ .send(model_response(indoc! {r#"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "#}))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap();
@@ -1689,114 +1941,150 @@ mod tests {
prediction.edits[0].0.to_point(&snapshot).start,
language::Point::new(1, 3)
);
- assert_eq!(prediction.edits[0].1, " are you?");
+ assert_eq!(prediction.edits[0].1.as_ref(), " are you?");
}
- #[gpui::test]
- async fn test_request_diagnostics(cx: &mut TestAppContext) {
- let (zeta, mut req_rx) = init_test(cx);
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/root",
- json!({
- "foo.md": "Hello!\nBye"
- }),
- )
- .await;
- let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+ // Skipped until we start including diagnostics in prompt
+ // #[gpui::test]
+ // async fn test_request_diagnostics(cx: &mut TestAppContext) {
+ // let (zeta, mut req_rx) = init_test(cx);
+ // let fs = FakeFs::new(cx.executor());
+ // fs.insert_tree(
+ // "/root",
+ // json!({
+ // "foo.md": "Hello!\nBye"
+ // }),
+ // )
+ // .await;
+ // let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ // let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
+ // let diagnostic = lsp::Diagnostic {
+ // range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
+ // severity: Some(lsp::DiagnosticSeverity::ERROR),
+ // message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
+ // ..Default::default()
+ // };
+
+ // project.update(cx, |project, cx| {
+ // project.lsp_store().update(cx, |lsp_store, cx| {
+ // // Create some diagnostics
+ // lsp_store
+ // .update_diagnostics(
+ // LanguageServerId(0),
+ // lsp::PublishDiagnosticsParams {
+ // uri: path_to_buffer_uri.clone(),
+ // diagnostics: vec![diagnostic],
+ // version: None,
+ // },
+ // None,
+ // language::DiagnosticSourceKind::Pushed,
+ // &[],
+ // cx,
+ // )
+ // .unwrap();
+ // });
+ // });
+
+ // let buffer = project
+ // .update(cx, |project, cx| {
+ // let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
+ // project.open_buffer(path, cx)
+ // })
+ // .await
+ // .unwrap();
+
+ // let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ // let position = snapshot.anchor_before(language::Point::new(0, 0));
+
+ // let _prediction_task = zeta.update(cx, |zeta, cx| {
+ // zeta.request_prediction(&project, &buffer, position, cx)
+ // });
+
+ // let (request, _respond_tx) = req_rx.next().await.unwrap();
+
+ // assert_eq!(request.diagnostic_groups.len(), 1);
+ // let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
+ // .unwrap();
+ // // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
+ // assert_eq!(
+ // value,
+ // json!({
+ // "entries": [{
+ // "range": {
+ // "start": 8,
+ // "end": 10
+ // },
+ // "diagnostic": {
+ // "source": null,
+ // "code": null,
+ // "code_description": null,
+ // "severity": 1,
+ // "message": "\"Hello\" deprecated. Use \"Hi\" instead",
+ // "markdown": null,
+ // "group_id": 0,
+ // "is_primary": true,
+ // "is_disk_based": false,
+ // "is_unnecessary": false,
+ // "source_kind": "Pushed",
+ // "data": null,
+ // "underline": true
+ // }
+ // }],
+ // "primary_ix": 0
+ // })
+ // );
+ // }
+
+ fn model_response(text: &str) -> open_ai::Response {
+ open_ai::Response {
+ id: Uuid::new_v4().to_string(),
+ object: "response".into(),
+ created: 0,
+ model: "model".into(),
+ choices: vec![open_ai::Choice {
+ index: 0,
+ message: open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(text.to_string())),
+ tool_calls: vec![],
+ },
+ finish_reason: None,
+ }],
+ usage: Usage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ }
+ }
- let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
- let diagnostic = lsp::Diagnostic {
- range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
- severity: Some(lsp::DiagnosticSeverity::ERROR),
- message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
- ..Default::default()
+ fn prompt_from_request(request: &open_ai::Request) -> &str {
+ assert_eq!(request.messages.len(), 1);
+ let open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(content),
+ ..
+ } = &request.messages[0]
+ else {
+ panic!(
+ "Request does not have single user message of type Plain. {:#?}",
+ request
+ );
};
-
- project.update(cx, |project, cx| {
- project.lsp_store().update(cx, |lsp_store, cx| {
- // Create some diagnostics
- lsp_store
- .update_diagnostics(
- LanguageServerId(0),
- lsp::PublishDiagnosticsParams {
- uri: path_to_buffer_uri.clone(),
- diagnostics: vec![diagnostic],
- version: None,
- },
- None,
- language::DiagnosticSourceKind::Pushed,
- &[],
- cx,
- )
- .unwrap();
- });
- });
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let position = snapshot.anchor_before(language::Point::new(0, 0));
-
- let _prediction_task = zeta.update(cx, |zeta, cx| {
- zeta.request_prediction(&project, &buffer, position, cx)
- });
-
- let (request, _respond_tx) = req_rx.next().await.unwrap();
-
- assert_eq!(request.diagnostic_groups.len(), 1);
- let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
- .unwrap();
- // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
- assert_eq!(
- value,
- json!({
- "entries": [{
- "range": {
- "start": 8,
- "end": 10
- },
- "diagnostic": {
- "source": null,
- "code": null,
- "code_description": null,
- "severity": 1,
- "message": "\"Hello\" deprecated. Use \"Hi\" instead",
- "markdown": null,
- "group_id": 0,
- "is_primary": true,
- "is_disk_based": false,
- "is_unnecessary": false,
- "source_kind": "Pushed",
- "data": null,
- "underline": true
- }
- }],
- "primary_ix": 0
- })
- );
+ content
}
fn init_test(
cx: &mut TestAppContext,
) -> (
Entity<Zeta>,
- mpsc::UnboundedReceiver<(
- predict_edits_v3::PredictEditsRequest,
- oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
- )>,
+ mpsc::UnboundedReceiver<(open_ai::Request, oneshot::Sender<open_ai::Response>)>,
) {
cx.update(move |cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
language::init(cx);
Project::init_settings(cx);
+ zlog::init_test();
let (req_tx, req_rx) = mpsc::unbounded();
@@ -27,10 +27,11 @@ log.workspace = true
multi_buffer.workspace = true
ordered-float.workspace = true
project.workspace = true
+regex-syntax = "0.8.8"
serde.workspace = true
+serde_json.workspace = true
telemetry.workspace = true
text.workspace = true
-regex-syntax = "0.8.8"
ui.workspace = true
ui_input.workspace = true
util.workspace = true
@@ -45,7 +45,6 @@ struct RetrievalRun {
started_at: Instant,
search_results_generated_at: Option<Instant>,
search_results_executed_at: Option<Instant>,
- search_results_filtered_at: Option<Instant>,
finished_at: Option<Instant>,
}
@@ -117,17 +116,12 @@ impl Zeta2ContextView {
self.handle_search_queries_executed(info, window, cx);
}
}
- ZetaDebugInfo::SearchResultsFiltered(info) => {
- if info.project == self.project {
- self.handle_search_results_filtered(info, window, cx);
- }
- }
ZetaDebugInfo::ContextRetrievalFinished(info) => {
if info.project == self.project {
self.handle_context_retrieval_finished(info, window, cx);
}
}
- ZetaDebugInfo::EditPredicted(_) => {}
+ ZetaDebugInfo::EditPredictionRequested(_) => {}
}
}
@@ -159,7 +153,6 @@ impl Zeta2ContextView {
started_at: info.timestamp,
search_results_generated_at: None,
search_results_executed_at: None,
- search_results_filtered_at: None,
finished_at: None,
});
@@ -218,18 +211,18 @@ impl Zeta2ContextView {
run.search_results_generated_at = Some(info.timestamp);
run.search_queries = info
- .queries
+ .regex_by_glob
.into_iter()
- .map(|query| {
+ .map(|(glob, regex)| {
let mut regex_parser = regex_syntax::ast::parse::Parser::new();
GlobQueries {
- glob: query.glob,
- alternations: match regex_parser.parse(&query.regex) {
+ glob,
+ alternations: match regex_parser.parse(®ex) {
Ok(regex_syntax::ast::Ast::Alternation(ref alt)) => {
alt.asts.iter().map(|ast| ast.to_string()).collect()
}
- _ => vec![query.regex],
+ _ => vec![regex],
},
}
})
@@ -256,20 +249,6 @@ impl Zeta2ContextView {
cx.notify();
}
- fn handle_search_results_filtered(
- &mut self,
- info: ZetaContextRetrievalDebugInfo,
- _window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let Some(run) = self.runs.back_mut() else {
- return;
- };
-
- run.search_results_filtered_at = Some(info.timestamp);
- cx.notify();
- }
-
fn handle_go_back(
&mut self,
_: &Zeta2ContextGoBack,
@@ -398,19 +377,10 @@ impl Zeta2ContextView {
};
div = div.child(format!("Ran search: {:>5} ms", (t2 - t1).as_millis()));
- let Some(t3) = run.search_results_filtered_at else {
- return pending_message(div, "Filtering results...");
- };
- div =
- div.child(format!("Filtered results: {:>5} ms", (t3 - t2).as_millis()));
-
- let Some(t4) = run.finished_at else {
- return pending_message(div, "Building excerpts");
- };
- div = div
- .child(format!("Build excerpts: {:>5} ยตs", (t4 - t3).as_micros()))
- .child(format!("Total: {:>5} ms", (t4 - t0).as_millis()));
- div
+ div.child(format!(
+ "Total: {:>5} ms",
+ (run.finished_at.unwrap_or(t0) - t0).as_millis()
+ ))
}),
)
}
@@ -5,7 +5,7 @@ use std::{cmp::Reverse, path::PathBuf, str::FromStr, sync::Arc, time::Duration};
use chrono::TimeDelta;
use client::{Client, UserStore};
use cloud_llm_client::predict_edits_v3::{
- self, DeclarationScoreComponents, PredictEditsRequest, PredictEditsResponse, PromptFormat,
+ DeclarationScoreComponents, PredictEditsRequest, PromptFormat,
};
use collections::HashMap;
use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
@@ -23,7 +23,7 @@ use ui_input::InputField;
use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
use workspace::{Item, SplitDirection, Workspace};
use zeta2::{
- ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, LlmContextOptions, Zeta, Zeta2FeatureFlag,
+ AgenticContextOptions, ContextMode, DEFAULT_SYNTAX_CONTEXT_OPTIONS, Zeta, Zeta2FeatureFlag,
ZetaDebugInfo, ZetaEditPredictionDebugInfo, ZetaOptions,
};
@@ -123,6 +123,7 @@ struct LastPrediction {
context_editor: Entity<Editor>,
prompt_editor: Entity<Editor>,
retrieval_time: TimeDelta,
+ request_time: Option<TimeDelta>,
buffer: WeakEntity<Buffer>,
position: language::Anchor,
state: LastPredictionState,
@@ -143,7 +144,7 @@ enum LastPredictionState {
model_response_editor: Entity<Editor>,
feedback_editor: Entity<Editor>,
feedback: Option<Feedback>,
- response: predict_edits_v3::PredictEditsResponse,
+ request_id: String,
},
Failed {
message: String,
@@ -217,7 +218,7 @@ impl Zeta2Inspector {
});
match &options.context {
- ContextMode::Llm(_) => {
+ ContextMode::Agentic(_) => {
self.context_mode = ContextModeState::Llm;
}
ContextMode::Syntax(_) => {
@@ -307,9 +308,11 @@ impl Zeta2Inspector {
};
let context = match zeta_options.context {
- ContextMode::Llm(_context_options) => ContextMode::Llm(LlmContextOptions {
- excerpt: excerpt_options,
- }),
+ ContextMode::Agentic(_context_options) => {
+ ContextMode::Agentic(AgenticContextOptions {
+ excerpt: excerpt_options,
+ })
+ }
ContextMode::Syntax(context_options) => {
let max_retrieved_declarations = match &this.context_mode {
ContextModeState::Llm => {
@@ -368,7 +371,7 @@ impl Zeta2Inspector {
let language_registry = self.project.read(cx).languages().clone();
async move |this, cx| {
let mut languages = HashMap::default();
- let ZetaDebugInfo::EditPredicted(prediction) = prediction else {
+ let ZetaDebugInfo::EditPredictionRequested(prediction) = prediction else {
return;
};
for ext in prediction
@@ -396,6 +399,8 @@ impl Zeta2Inspector {
.await
.log_err();
+ let json_language = language_registry.language_for_name("Json").await.log_err();
+
this.update_in(cx, |this, window, cx| {
let context_editor = cx.new(|cx| {
let mut excerpt_score_components = HashMap::default();
@@ -492,25 +497,15 @@ impl Zeta2Inspector {
let task = cx.spawn_in(window, {
let markdown_language = markdown_language.clone();
+ let json_language = json_language.clone();
async move |this, cx| {
let response = response_rx.await;
this.update_in(cx, |this, window, cx| {
if let Some(prediction) = this.last_prediction.as_mut() {
prediction.state = match response {
- Ok(Ok(response)) => {
- if let Some(debug_info) = &response.debug_info {
- prediction.prompt_editor.update(
- cx,
- |prompt_editor, cx| {
- prompt_editor.set_text(
- debug_info.prompt.as_str(),
- window,
- cx,
- );
- },
- );
- }
+ Ok((Ok(response), request_time)) => {
+ prediction.request_time = Some(request_time);
let feedback_editor = cx.new(|cx| {
let buffer = cx.new(|cx| {
@@ -577,16 +572,11 @@ impl Zeta2Inspector {
model_response_editor: cx.new(|cx| {
let buffer = cx.new(|cx| {
let mut buffer = Buffer::local(
- response
- .debug_info
- .as_ref()
- .map(|p| p.model_response.as_str())
- .unwrap_or(
- "(Debug info not available)",
- ),
+ serde_json::to_string_pretty(&response)
+ .unwrap_or_default(),
cx,
);
- buffer.set_language(markdown_language, cx);
+ buffer.set_language(json_language, cx);
buffer
});
let buffer = cx.new(|cx| {
@@ -607,10 +597,11 @@ impl Zeta2Inspector {
}),
feedback_editor,
feedback: None,
- response,
+ request_id: response.id.clone(),
}
}
- Ok(Err(err)) => {
+ Ok((Err(err), request_time)) => {
+ prediction.request_time = Some(request_time);
LastPredictionState::Failed { message: err }
}
Err(oneshot::Canceled) => LastPredictionState::Failed {
@@ -644,6 +635,7 @@ impl Zeta2Inspector {
editor
}),
retrieval_time,
+ request_time: None,
buffer,
position,
state: LastPredictionState::Requested,
@@ -700,7 +692,7 @@ impl Zeta2Inspector {
feedback: feedback_state,
feedback_editor,
model_response_editor,
- response,
+ request_id,
..
} = &mut last_prediction.state
else {
@@ -734,11 +726,10 @@ impl Zeta2Inspector {
telemetry::event!(
"Zeta2 Prediction Rated",
- id = response.request_id,
+ id = request_id,
kind = kind,
text = text,
request = last_prediction.request,
- response = response,
project_snapshot = project_snapshot,
);
})
@@ -834,11 +825,11 @@ impl Zeta2Inspector {
let current_options =
this.zeta.read(cx).options().clone();
match current_options.context.clone() {
- ContextMode::Llm(_) => {}
+ ContextMode::Agentic(_) => {}
ContextMode::Syntax(context_options) => {
let options = ZetaOptions {
- context: ContextMode::Llm(
- LlmContextOptions {
+ context: ContextMode::Agentic(
+ AgenticContextOptions {
excerpt: context_options.excerpt,
},
),
@@ -865,7 +856,7 @@ impl Zeta2Inspector {
let current_options =
this.zeta.read(cx).options().clone();
match current_options.context.clone() {
- ContextMode::Llm(context_options) => {
+ ContextMode::Agentic(context_options) => {
let options = ZetaOptions {
context: ContextMode::Syntax(
EditPredictionContextOptions {
@@ -976,25 +967,6 @@ impl Zeta2Inspector {
return None;
};
- let (prompt_planning_time, inference_time, parsing_time) =
- if let LastPredictionState::Success {
- response:
- PredictEditsResponse {
- debug_info: Some(debug_info),
- ..
- },
- ..
- } = &prediction.state
- {
- (
- Some(debug_info.prompt_planning_time),
- Some(debug_info.inference_time),
- Some(debug_info.parsing_time),
- )
- } else {
- (None, None, None)
- };
-
Some(
v_flex()
.p_4()
@@ -1005,12 +977,7 @@ impl Zeta2Inspector {
"Context retrieval",
Some(prediction.retrieval_time),
))
- .child(Self::render_duration(
- "Prompt planning",
- prompt_planning_time,
- ))
- .child(Self::render_duration("Inference", inference_time))
- .child(Self::render_duration("Parsing", parsing_time)),
+ .child(Self::render_duration("Request", prediction.request_time)),
)
}
@@ -7,13 +7,14 @@ use std::{
use anyhow::Result;
use clap::Args;
-use cloud_llm_client::udiff::DiffLine;
use collections::HashSet;
use gpui::AsyncApp;
+use zeta2::udiff::DiffLine;
use crate::{
example::{Example, NamedExample},
headless::ZetaCliAppState,
+ paths::CACHE_DIR,
predict::{PredictionDetails, zeta2_predict},
};
@@ -54,10 +55,8 @@ pub async fn run_evaluate_one(
app_state: Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<EvaluationResult> {
- let cache_dir = Path::new(&std::env::var("CARGO_MANIFEST_DIR").unwrap_or_default())
- .join("../../target/zeta-prediction-cache");
let example = NamedExample::load(&example_path).unwrap();
- let example_cache_path = cache_dir.join(&example_path.file_name().unwrap());
+ let example_cache_path = CACHE_DIR.join(&example_path.file_name().unwrap());
let predictions = if !re_run && example_cache_path.exists() {
let file_contents = fs::read_to_string(&example_cache_path)?;
@@ -74,7 +73,7 @@ pub async fn run_evaluate_one(
};
if !example_cache_path.exists() {
- fs::create_dir_all(&cache_dir).unwrap();
+ fs::create_dir_all(&*CACHE_DIR).unwrap();
fs::write(
example_cache_path,
serde_json::to_string(&predictions).unwrap(),
@@ -1,28 +1,31 @@
use std::{
borrow::Cow,
cell::RefCell,
- env,
fmt::{self, Display},
fs,
io::Write,
mem,
- ops::Range,
path::{Path, PathBuf},
sync::Arc,
};
-use anyhow::{Context as _, Result};
+use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
-use collections::{HashMap, HashSet};
+use cloud_zeta2_prompt::CURSOR_MARKER;
+use collections::HashMap;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use gpui::{AsyncApp, Entity, http_client::Url};
-use language::Buffer;
+use language::{Anchor, Buffer};
use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
+use util::{paths::PathStyle, rel_path::RelPath};
+use zeta2::udiff::OpenedBuffers;
+
+use crate::paths::{REPOS_DIR, WORKTREES_DIR};
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
@@ -215,12 +218,10 @@ impl NamedExample {
let (repo_owner, repo_name) = self.repo_name()?;
let file_name = self.file_name();
- let worktrees_dir = env::current_dir()?.join("target").join("zeta-worktrees");
- let repos_dir = env::current_dir()?.join("target").join("zeta-repos");
- fs::create_dir_all(&repos_dir)?;
- fs::create_dir_all(&worktrees_dir)?;
+ fs::create_dir_all(&*REPOS_DIR)?;
+ fs::create_dir_all(&*WORKTREES_DIR)?;
- let repo_dir = repos_dir.join(repo_owner.as_ref()).join(repo_name.as_ref());
+ let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
@@ -251,7 +252,7 @@ impl NamedExample {
};
// Create the worktree for this example if needed.
- let worktree_path = worktrees_dir.join(&file_name);
+ let worktree_path = WORKTREES_DIR.join(&file_name);
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
@@ -309,7 +310,6 @@ impl NamedExample {
.collect()
}
- #[allow(unused)]
fn repo_name(&self) -> Result<(Cow<'_, str>, Cow<'_, str>)> {
// git@github.com:owner/repo.git
if self.example.repository_url.contains('@') {
@@ -344,13 +344,63 @@ impl NamedExample {
}
}
+ pub async fn cursor_position(
+ &self,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+ ) -> Result<(Entity<Buffer>, Anchor)> {
+ let worktree = project.read_with(cx, |project, cx| {
+ project.visible_worktrees(cx).next().unwrap()
+ })?;
+ let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
+ let cursor_buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(
+ ProjectPath {
+ worktree_id: worktree.read(cx).id(),
+ path: cursor_path,
+ },
+ cx,
+ )
+ })?
+ .await?;
+ let cursor_offset_within_excerpt = self
+ .example
+ .cursor_position
+ .find(CURSOR_MARKER)
+ .ok_or_else(|| anyhow!("missing cursor marker"))?;
+ let mut cursor_excerpt = self.example.cursor_position.clone();
+ cursor_excerpt.replace_range(
+ cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+ "",
+ );
+ let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+ let text = buffer.text();
+
+ let mut matches = text.match_indices(&cursor_excerpt);
+ let Some((excerpt_offset, _)) = matches.next() else {
+ anyhow::bail!(
+ "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
+ );
+ };
+ assert!(matches.next().is_none());
+
+ Ok(excerpt_offset)
+ })??;
+
+ let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+ let cursor_anchor =
+ cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+ Ok((cursor_buffer, cursor_anchor))
+ }
+
#[must_use]
pub async fn apply_edit_history(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
- ) -> Result<HashSet<Entity<Buffer>>> {
- apply_diff(&self.example.edit_history, project, cx).await
+ ) -> Result<OpenedBuffers<'_>> {
+ zeta2::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
@@ -446,404 +496,3 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
.lock_owned()
.await
}
-
-#[must_use]
-pub async fn apply_diff(
- diff: &str,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
-) -> Result<HashSet<Entity<Buffer>>> {
- use cloud_llm_client::udiff::DiffLine;
- use std::fmt::Write;
-
- #[derive(Debug, Default)]
- struct HunkState {
- context: String,
- edits: Vec<Edit>,
- }
-
- #[derive(Debug)]
- struct Edit {
- range: Range<usize>,
- text: String,
- }
-
- let mut old_path = None;
- let mut new_path = None;
- let mut hunk = HunkState::default();
- let mut diff_lines = diff.lines().map(DiffLine::parse).peekable();
- let mut open_buffers = HashSet::default();
-
- while let Some(diff_line) = diff_lines.next() {
- match diff_line {
- DiffLine::OldPath { path } => old_path = Some(path),
- DiffLine::NewPath { path } => {
- if old_path.is_none() {
- anyhow::bail!(
- "Found a new path header (`+++`) before an (`---`) old path header"
- );
- }
- new_path = Some(path)
- }
- DiffLine::Context(ctx) => {
- writeln!(&mut hunk.context, "{ctx}")?;
- }
- DiffLine::Deletion(del) => {
- let range = hunk.context.len()..hunk.context.len() + del.len() + '\n'.len_utf8();
- if let Some(last_edit) = hunk.edits.last_mut()
- && last_edit.range.end == range.start
- {
- last_edit.range.end = range.end;
- } else {
- hunk.edits.push(Edit {
- range,
- text: String::new(),
- });
- }
- writeln!(&mut hunk.context, "{del}")?;
- }
- DiffLine::Addition(add) => {
- let range = hunk.context.len()..hunk.context.len();
- if let Some(last_edit) = hunk.edits.last_mut()
- && last_edit.range.end == range.start
- {
- writeln!(&mut last_edit.text, "{add}").unwrap();
- } else {
- hunk.edits.push(Edit {
- range,
- text: format!("{add}\n"),
- });
- }
- }
- DiffLine::HunkHeader(_) | DiffLine::Garbage(_) => {}
- }
-
- let at_hunk_end = match diff_lines.peek() {
- Some(DiffLine::OldPath { .. }) | Some(DiffLine::HunkHeader(_)) | None => true,
- _ => false,
- };
-
- if at_hunk_end {
- let hunk = mem::take(&mut hunk);
-
- let Some(old_path) = old_path.as_deref() else {
- anyhow::bail!("Missing old path (`---`) header")
- };
-
- let Some(new_path) = new_path.as_deref() else {
- anyhow::bail!("Missing new path (`+++`) header")
- };
-
- let buffer = project
- .update(cx, |project, cx| {
- let project_path = project
- .find_project_path(old_path, cx)
- .context("Failed to find old_path in project")?;
-
- anyhow::Ok(project.open_buffer(project_path, cx))
- })??
- .await?;
- open_buffers.insert(buffer.clone());
-
- if old_path != new_path {
- project
- .update(cx, |project, cx| {
- let project_file = project::File::from_dyn(buffer.read(cx).file()).unwrap();
- let new_path = ProjectPath {
- worktree_id: project_file.worktree_id(cx),
- path: project_file.path.clone(),
- };
- project.rename_entry(project_file.entry_id.unwrap(), new_path, cx)
- })?
- .await?;
- }
-
- // TODO is it worth using project search?
- buffer.update(cx, |buffer, cx| {
- let context_offset = if hunk.context.is_empty() {
- 0
- } else {
- let text = buffer.text();
- if let Some(offset) = text.find(&hunk.context) {
- if text[offset + 1..].contains(&hunk.context) {
- anyhow::bail!("Context is not unique enough:\n{}", hunk.context);
- }
- offset
- } else {
- anyhow::bail!(
- "Failed to match context:\n{}\n\nBuffer:\n{}",
- hunk.context,
- text
- );
- }
- };
-
- buffer.edit(
- hunk.edits.into_iter().map(|edit| {
- (
- context_offset + edit.range.start..context_offset + edit.range.end,
- edit.text,
- )
- }),
- None,
- cx,
- );
-
- anyhow::Ok(())
- })??;
- }
- }
-
- anyhow::Ok(open_buffers)
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use ::fs::FakeFs;
- use gpui::TestAppContext;
- use indoc::indoc;
- use pretty_assertions::assert_eq;
- use project::Project;
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
-
- #[gpui::test]
- async fn test_apply_diff_successful(cx: &mut TestAppContext) {
- let buffer_1_text = indoc! {r#"
- one
- two
- three
- four
- five
- "# };
-
- let buffer_1_text_final = indoc! {r#"
- 3
- 4
- 5
- "# };
-
- let buffer_2_text = indoc! {r#"
- six
- seven
- eight
- nine
- ten
- "# };
-
- let buffer_2_text_final = indoc! {r#"
- 5
- six
- seven
- 7.5
- eight
- nine
- ten
- 11
- "# };
-
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- Project::init_settings(cx);
- language::init(cx);
- });
-
- let fs = FakeFs::new(cx.background_executor.clone());
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- "file2": buffer_2_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one
- two
- -three
- +3
- four
- five
- --- a/root/file1
- +++ b/root/file1
- 3
- -four
- -five
- +4
- +5
- --- a/root/file1
- +++ b/root/file1
- -one
- -two
- 3
- 4
- --- a/root/file2
- +++ b/root/file2
- +5
- six
- --- a/root/file2
- +++ b/root/file2
- seven
- +7.5
- eight
- --- a/root/file2
- +++ b/root/file2
- ten
- +11
- "#};
-
- let _buffers = apply_diff(diff, &project, &mut cx.to_async())
- .await
- .unwrap();
- let buffer_1 = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- buffer_1.read_with(cx, |buffer, _cx| {
- assert_eq!(buffer.text(), buffer_1_text_final);
- });
- let buffer_2 = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path(path!("/root/file2"), cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- buffer_2.read_with(cx, |buffer, _cx| {
- assert_eq!(buffer.text(), buffer_2_text_final);
- });
- }
-
- #[gpui::test]
- async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
- let buffer_1_text = indoc! {r#"
- one
- two
- three
- four
- five
- one
- two
- three
- four
- five
- "# };
-
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- Project::init_settings(cx);
- language::init(cx);
- });
-
- let fs = FakeFs::new(cx.background_executor.clone());
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one
- two
- -three
- +3
- four
- five
- "#};
-
- apply_diff(diff, &project, &mut cx.to_async())
- .await
- .expect_err("Non-unique edits should fail");
- }
-
- #[gpui::test]
- async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
- let start = indoc! {r#"
- one
- two
- three
- four
- five
-
- four
- five
- "# };
-
- let end = indoc! {r#"
- one
- two
- 3
- four
- 5
-
- four
- five
- "# };
-
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- Project::init_settings(cx);
- language::init(cx);
- });
-
- let fs = FakeFs::new(cx.background_executor.clone());
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": start,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one
- two
- -three
- +3
- four
- -five
- +5
- "#};
-
- let _buffers = apply_diff(diff, &project, &mut cx.to_async())
- .await
- .unwrap();
-
- let buffer_1 = project
- .update(cx, |project, cx| {
- let project_path = project.find_project_path(path!("/root/file1"), cx).unwrap();
- project.open_buffer(project_path, cx)
- })
- .await
- .unwrap();
-
- buffer_1.read_with(cx, |buffer, _cx| {
- assert_eq!(buffer.text(), end);
- });
- }
-}
@@ -1,6 +1,7 @@
mod evaluate;
mod example;
mod headless;
+mod paths;
mod predict;
mod source_location;
mod syntax_retrieval_stats;
@@ -10,28 +11,22 @@ use crate::evaluate::{EvaluateArguments, run_evaluate};
use crate::example::{ExampleFormat, NamedExample};
use crate::predict::{PredictArguments, run_zeta2_predict};
use crate::syntax_retrieval_stats::retrieval_stats;
-use ::serde::Serialize;
use ::util::paths::PathStyle;
-use anyhow::{Context as _, Result, anyhow};
+use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand};
-use cloud_llm_client::predict_edits_v3::{self, Excerpt};
-use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
+use cloud_llm_client::predict_edits_v3;
use edit_prediction_context::{
- EditPredictionContextOptions, EditPredictionExcerpt, EditPredictionExcerptOptions,
- EditPredictionScoreOptions, Line,
+ EditPredictionContextOptions, EditPredictionExcerptOptions, EditPredictionScoreOptions,
};
-use futures::StreamExt as _;
-use futures::channel::mpsc;
use gpui::{Application, AsyncApp, Entity, prelude::*};
-use language::{Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point};
-use language_model::LanguageModelRegistry;
+use language::{Bias, Buffer, BufferSnapshot, Point};
use project::{Project, Worktree};
use reqwest_client::ReqwestClient;
use serde_json::json;
use std::io::{self};
use std::time::Duration;
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
-use zeta2::{ContextMode, LlmContextOptions, SearchToolQuery};
+use zeta2::ContextMode;
use crate::headless::ZetaCliAppState;
use crate::source_location::SourceLocation;
@@ -79,12 +74,6 @@ enum Zeta2Command {
#[command(subcommand)]
command: Zeta2SyntaxCommand,
},
- Llm {
- #[clap(flatten)]
- args: Zeta2Args,
- #[command(subcommand)]
- command: Zeta2LlmCommand,
- },
Predict(PredictArguments),
Eval(EvaluateArguments),
}
@@ -107,14 +96,6 @@ enum Zeta2SyntaxCommand {
},
}
-#[derive(Subcommand, Debug)]
-enum Zeta2LlmCommand {
- Context {
- #[clap(flatten)]
- context_args: ContextArgs,
- },
-}
-
#[derive(Debug, Args)]
#[group(requires = "worktree")]
struct ContextArgs {
@@ -388,197 +369,6 @@ async fn zeta2_syntax_context(
Ok(output)
}
-async fn zeta2_llm_context(
- zeta2_args: Zeta2Args,
- context_args: ContextArgs,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<String> {
- let LoadedContext {
- buffer,
- clipped_cursor,
- snapshot: cursor_snapshot,
- project,
- ..
- } = load_context(&context_args, app_state, cx).await?;
-
- let cursor_position = cursor_snapshot.anchor_after(clipped_cursor);
-
- cx.update(|cx| {
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry
- .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
- .unwrap()
- .authenticate(cx)
- })
- })?
- .await?;
-
- let edit_history_unified_diff = match context_args.edit_history {
- Some(events) => events.read_to_string().await?,
- None => String::new(),
- };
-
- let (debug_tx, mut debug_rx) = mpsc::unbounded();
-
- let excerpt_options = EditPredictionExcerptOptions {
- max_bytes: zeta2_args.max_excerpt_bytes,
- min_bytes: zeta2_args.min_excerpt_bytes,
- target_before_cursor_over_total_bytes: zeta2_args.target_before_cursor_over_total_bytes,
- };
-
- let related_excerpts = cx
- .update(|cx| {
- zeta2::related_excerpts::find_related_excerpts(
- buffer,
- cursor_position,
- &project,
- edit_history_unified_diff,
- &LlmContextOptions {
- excerpt: excerpt_options.clone(),
- },
- Some(debug_tx),
- cx,
- )
- })?
- .await?;
-
- let cursor_excerpt = EditPredictionExcerpt::select_from_buffer(
- clipped_cursor,
- &cursor_snapshot,
- &excerpt_options,
- None,
- )
- .context("line didn't fit")?;
-
- #[derive(Serialize)]
- struct Output {
- excerpts: Vec<OutputExcerpt>,
- formatted_excerpts: String,
- meta: OutputMeta,
- }
-
- #[derive(Default, Serialize)]
- struct OutputMeta {
- search_prompt: String,
- search_queries: Vec<SearchToolQuery>,
- }
-
- #[derive(Serialize)]
- struct OutputExcerpt {
- path: PathBuf,
- #[serde(flatten)]
- excerpt: Excerpt,
- }
-
- let mut meta = OutputMeta::default();
-
- while let Some(debug_info) = debug_rx.next().await {
- match debug_info {
- zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
- meta.search_prompt = info.search_prompt;
- }
- zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
- meta.search_queries = info.queries
- }
- _ => {}
- }
- }
-
- cx.update(|cx| {
- let mut excerpts = Vec::new();
- let mut formatted_excerpts = String::new();
-
- let cursor_insertions = [(
- predict_edits_v3::Point {
- line: Line(clipped_cursor.row),
- column: clipped_cursor.column,
- },
- CURSOR_MARKER,
- )];
-
- let mut cursor_excerpt_added = false;
-
- for (buffer, ranges) in related_excerpts {
- let excerpt_snapshot = buffer.read(cx).snapshot();
-
- let mut line_ranges = ranges
- .into_iter()
- .map(|range| {
- let point_range = range.to_point(&excerpt_snapshot);
- Line(point_range.start.row)..Line(point_range.end.row)
- })
- .collect::<Vec<_>>();
-
- let Some(file) = excerpt_snapshot.file() else {
- continue;
- };
- let path = file.full_path(cx);
-
- let is_cursor_file = path == cursor_snapshot.file().unwrap().full_path(cx);
- if is_cursor_file {
- let insertion_ix = line_ranges
- .binary_search_by(|probe| {
- probe
- .start
- .cmp(&cursor_excerpt.line_range.start)
- .then(cursor_excerpt.line_range.end.cmp(&probe.end))
- })
- .unwrap_or_else(|ix| ix);
- line_ranges.insert(insertion_ix, cursor_excerpt.line_range.clone());
- cursor_excerpt_added = true;
- }
-
- let merged_excerpts =
- zeta2::merge_excerpts::merge_excerpts(&excerpt_snapshot, line_ranges)
- .into_iter()
- .map(|excerpt| OutputExcerpt {
- path: path.clone(),
- excerpt,
- });
-
- let excerpt_start_ix = excerpts.len();
- excerpts.extend(merged_excerpts);
-
- write_codeblock(
- &path,
- excerpts[excerpt_start_ix..].iter().map(|e| &e.excerpt),
- if is_cursor_file {
- &cursor_insertions
- } else {
- &[]
- },
- Line(excerpt_snapshot.max_point().row),
- true,
- &mut formatted_excerpts,
- );
- }
-
- if !cursor_excerpt_added {
- write_codeblock(
- &cursor_snapshot.file().unwrap().full_path(cx),
- &[Excerpt {
- start_line: cursor_excerpt.line_range.start,
- text: cursor_excerpt.text(&cursor_snapshot).body.into(),
- }],
- &cursor_insertions,
- Line(cursor_snapshot.max_point().row),
- true,
- &mut formatted_excerpts,
- );
- }
-
- let output = Output {
- excerpts,
- formatted_excerpts,
- meta,
- };
-
- Ok(serde_json::to_string_pretty(&output)?)
- })
- .unwrap()
-}
-
async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
@@ -670,13 +460,6 @@ fn main() {
};
println!("{}", result.unwrap());
}
- Zeta2Command::Llm { args, command } => match command {
- Zeta2LlmCommand::Context { context_args } => {
- let result =
- zeta2_llm_context(args, context_args, &app_state, cx).await;
- println!("{}", result.unwrap());
- }
- },
},
Command::ConvertExample {
path,
@@ -0,0 +1,8 @@
+use std::{env, path::PathBuf, sync::LazyLock};
+
+static TARGET_DIR: LazyLock<PathBuf> = LazyLock::new(|| env::current_dir().unwrap().join("target"));
+pub static CACHE_DIR: LazyLock<PathBuf> =
+ LazyLock::new(|| TARGET_DIR.join("zeta-prediction-cache"));
+pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-repos"));
+pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-worktrees"));
+pub static LOGS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_DIR.join("zeta-logs"));
@@ -1,22 +1,20 @@
use crate::example::{ActualExcerpt, NamedExample};
-
use crate::headless::ZetaCliAppState;
+use crate::paths::LOGS_DIR;
use ::serde::Serialize;
-use ::util::paths::PathStyle;
use anyhow::{Context as _, Result, anyhow};
use clap::Args;
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use futures::StreamExt as _;
use gpui::AsyncApp;
-use language_model::LanguageModelRegistry;
-use project::{Project, ProjectPath};
+use project::Project;
use serde::Deserialize;
use std::cell::Cell;
+use std::fs;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
-use util::rel_path::RelPath;
#[derive(Debug, Args)]
pub struct PredictArguments {
@@ -50,21 +48,12 @@ pub async fn zeta2_predict(
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
+ fs::create_dir_all(&*LOGS_DIR)?;
let worktree_path = example.setup_worktree().await?;
if !AUTHENTICATED.get() {
AUTHENTICATED.set(true);
- cx.update(|cx| {
- LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
- registry
- .provider(&zeta2::related_excerpts::MODEL_PROVIDER_ID)
- .unwrap()
- .authenticate(cx)
- })
- })?
- .await?;
-
app_state
.client
.sign_in_with_optional_connect(true, cx)
@@ -83,6 +72,8 @@ pub async fn zeta2_predict(
)
})?;
+ let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
@@ -94,58 +85,30 @@ pub async fn zeta2_predict(
})?
.await;
- let _edited_buffers = example.apply_edit_history(&project, cx).await?;
-
- let cursor_path = RelPath::new(&example.example.cursor_path, PathStyle::Posix)?.into_arc();
-
- let cursor_buffer = project
- .update(cx, |project, cx| {
- project.open_buffer(
- ProjectPath {
- worktree_id: worktree.read(cx).id(),
- path: cursor_path,
- },
- cx,
- )
- })?
- .await?;
-
- let cursor_offset_within_excerpt = example
- .example
- .cursor_position
- .find(CURSOR_MARKER)
- .ok_or_else(|| anyhow!("missing cursor marker"))?;
- let mut cursor_excerpt = example.example.cursor_position.clone();
- cursor_excerpt.replace_range(
- cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
- "",
- );
- let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
- let text = buffer.text();
-
- let mut matches = text.match_indices(&cursor_excerpt);
- let Some((excerpt_offset, _)) = matches.next() else {
- anyhow::bail!(
- "Cursor excerpt did not exist in buffer.\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n"
- );
- };
- assert!(matches.next().is_none());
+ let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
- Ok(excerpt_offset)
- })??;
+ cx.subscribe(&buffer_store, {
+ let project = project.clone();
+ move |_, event, cx| match event {
+ project::buffer_store::BufferStoreEvent::BufferAdded(buffer) => {
+ zeta2::Zeta::try_global(cx)
+ .unwrap()
+ .update(cx, |zeta, cx| zeta.register_buffer(&buffer, &project, cx));
+ }
+ _ => {}
+ }
+ })?
+ .detach();
- let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
- let cursor_anchor =
- cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
+ let _edited_buffers = example.apply_edit_history(&project, cx).await?;
+ let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
- let zeta = cx.update(|cx| zeta2::Zeta::global(&app_state.client, &app_state.user_store, cx))?;
+ let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
let refresh_task = zeta.update(cx, |zeta, cx| {
- zeta.register_buffer(&cursor_buffer, &project, cx);
zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
})?;
- let mut debug_rx = zeta.update(cx, |zeta, _| zeta.debug_info())?;
let mut context_retrieval_started_at = None;
let mut context_retrieval_finished_at = None;
let mut search_queries_generated_at = None;
@@ -159,9 +122,14 @@ pub async fn zeta2_predict(
match event {
zeta2::ZetaDebugInfo::ContextRetrievalStarted(info) => {
context_retrieval_started_at = Some(info.timestamp);
+ fs::write(LOGS_DIR.join("search_prompt.md"), &info.search_prompt)?;
}
zeta2::ZetaDebugInfo::SearchQueriesGenerated(info) => {
search_queries_generated_at = Some(info.timestamp);
+ fs::write(
+ LOGS_DIR.join("search_queries.json"),
+ serde_json::to_string_pretty(&info.regex_by_glob).unwrap(),
+ )?;
}
zeta2::ZetaDebugInfo::SearchQueriesExecuted(info) => {
search_queries_executed_at = Some(info.timestamp);
@@ -173,11 +141,21 @@ pub async fn zeta2_predict(
zeta.request_prediction(&project, &cursor_buffer, cursor_anchor, cx)
})?);
}
- zeta2::ZetaDebugInfo::EditPredicted(request) => {
+ zeta2::ZetaDebugInfo::EditPredictionRequested(request) => {
prediction_started_at = Some(Instant::now());
- request.response_rx.await?.map_err(|err| anyhow!(err))?;
+ fs::write(
+ LOGS_DIR.join("prediction_prompt.md"),
+ &request.local_prompt.unwrap_or_default(),
+ )?;
+
+ let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
prediction_finished_at = Some(Instant::now());
+ fs::write(
+ LOGS_DIR.join("prediction_response.json"),
+ &serde_json::to_string_pretty(&response).unwrap(),
+ )?;
+
for included_file in request.request.included_files {
let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
result
@@ -201,7 +179,6 @@ pub async fn zeta2_predict(
}
break;
}
- _ => {}
}
}