diff --git a/Cargo.lock b/Cargo.lock
index a8f602640838d3634863fc60a2399e8a9a9f5288..ff1041695e1f1e95bcbc05798d1a1e0f953533ff 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3111,16 +3111,6 @@ dependencies = [
"uuid",
]
-[[package]]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "cloud_llm_client",
- "indoc",
- "serde",
-]
-
[[package]]
name = "cmake"
version = "0.1.54"
@@ -5119,7 +5109,6 @@ dependencies = [
"clock",
"cloud_api_types",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"collections",
"copilot",
"credentials_provider",
@@ -5150,8 +5139,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
- "smol",
- "strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
@@ -5162,6 +5149,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
+ "zeta_prompt",
"zlog",
]
@@ -5175,11 +5163,10 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"collections",
"debug_adapter_extension",
+ "dirs 4.0.0",
"edit_prediction",
- "edit_prediction_context",
"extension",
"fs",
"futures 0.3.31",
@@ -5209,9 +5196,10 @@ dependencies = [
"sqlez",
"sqlez_macros",
"terminal_view",
- "toml 0.8.23",
"util",
+ "wasmtime",
"watch",
+ "zeta_prompt",
"zlog",
]
@@ -5239,6 +5227,7 @@ dependencies = [
"text",
"tree-sitter",
"util",
+ "zeta_prompt",
"zlog",
]
@@ -5260,7 +5249,6 @@ dependencies = [
"buffer_diff",
"client",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"codestral",
"command_palette_hooks",
"copilot",
@@ -5291,6 +5279,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
+ "zeta_prompt",
]
[[package]]
@@ -20933,6 +20922,13 @@ dependencies = [
"syn 2.0.106",
]
+[[package]]
+name = "zeta_prompt"
+version = "0.1.0"
+dependencies = [
+ "serde",
+]
+
[[package]]
name = "zip"
version = "0.6.6"
diff --git a/Cargo.toml b/Cargo.toml
index 0ad4d2b14523988aa0dd6e3bfc935f84bcd0d8d9..fcbe5c829ded21a9aaf9e6bec93b9955b1db6447 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,7 +32,6 @@ members = [
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
- "crates/cloud_zeta2_prompt",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -202,6 +201,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/edit_prediction_cli",
+ "crates/zeta_prompt",
"crates/zlog",
"crates/zlog_settings",
"crates/ztracing",
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
-cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
+zeta_prompt = { path = "crates/zeta_prompt" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
ztracing = { path = "crates/ztracing" }
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
tiny_http = "0.8"
tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
+tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
toml = "0.8"
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
tower-http = "0.4.4"
diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml
index 7149ad4f55feaae5b596a39a3dd460d71cc5daa5..50cf12b977a62d56bf9d4a036165917a5dfff2fc 100644
--- a/crates/client/Cargo.toml
+++ b/crates/client/Cargo.toml
@@ -53,7 +53,7 @@ text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
-tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
+tokio-socks.workspace = true
tokio.workspace = true
url.workspace = true
util.workspace = true
diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml
deleted file mode 100644
index a15e3fe43c28349920433272c4040ccc58ff4cb4..0000000000000000000000000000000000000000
--- a/crates/cloud_zeta2_prompt/Cargo.toml
+++ /dev/null
@@ -1,18 +0,0 @@
-[package]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-publish.workspace = true
-edition.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/cloud_zeta2_prompt.rs"
-
-[dependencies]
-anyhow.workspace = true
-cloud_llm_client.workspace = true
-indoc.workspace = true
-serde.workspace = true
diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs
deleted file mode 100644
index 62bfa45f47d0fdfefa9fbd72320c0ddee71cbc47..0000000000000000000000000000000000000000
--- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs
+++ /dev/null
@@ -1,485 +0,0 @@
-use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::{
- self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
-};
-use indoc::indoc;
-use std::cmp;
-use std::fmt::Write;
-use std::path::Path;
-use std::sync::Arc;
-
-pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
-
-pub const CURSOR_MARKER: &str = "<|user_cursor|>";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
-
-const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
- You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
-
- ## Edit History
-
- "#};
-
-const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
- ---
-
- Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
- Do not include the cursor marker in your output.
- If you're editing multiple files, be sure to reflect filename in the hunk's header.
- "};
-
-const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
- # Instructions
-
- You are an edit prediction agent in a code editor.
-
- Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
- Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
- Always continue along the user's current trajectory, rather than changing course.
-
- ## Output Format
-
- You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
- along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
-
-
-
- OLD TEXT 1 HERE
-
-
- NEW TEXT 1 HERE
-
-
-
- OLD TEXT 1 HERE
-
-
- NEW TEXT 1 HERE
-
-
-
- - Specify the file to edit using the `path` attribute.
- - Use `` and `` tags to replace content
- - `` must exactly match existing file content, including indentation
- - `` cannot be empty
- - Do not escape quotes, newlines, or other characters within tags
- - Always close all tags properly
- - Don't include the <|user_cursor|> marker in your output.
-
- ## Edit History
-
-"#};
-
-const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
- ---
-
- Remember that the edits in the edit history have already been applied.
-"#};
-
-pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result {
- let prompt_data = PromptData {
- events: request.events.clone(),
- cursor_point: request.cursor_point,
- cursor_path: request.excerpt_path.clone(),
- included_files: request.related_files.clone(),
- };
- match request.prompt_format {
- PromptFormat::MinimalQwen => {
- return Ok(MinimalQwenPrompt.render(&prompt_data));
- }
- PromptFormat::SeedCoder1120 => {
- return Ok(SeedCoder1120Prompt.render(&prompt_data));
- }
- _ => (),
- };
-
- let insertions = match request.prompt_format {
- PromptFormat::Minimal | PromptFormat::OldTextNewText => {
- vec![(request.cursor_point, CURSOR_MARKER)]
- }
- PromptFormat::OnlySnippets => vec![],
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- };
-
- let mut prompt = match request.prompt_format {
- PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
- PromptFormat::OnlySnippets => String::new(),
- PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- };
-
- if request.events.is_empty() {
- prompt.push_str("(No edit history)\n\n");
- } else {
- let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
- "The following are the latest edits made by the user, from earlier to later.\n\n"
- } else {
- "Here are the latest edits made by the user, from earlier to later.\n\n"
- };
- prompt.push_str(edit_preamble);
- push_events(&mut prompt, &request.events);
- }
-
- let excerpts_preamble = match request.prompt_format {
- PromptFormat::Minimal => indoc! {"
- ## Part of the file under the cursor
-
- (The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history has been applied.
- We only show part of the file around the cursor.
- You can only edit exactly this part of the file.
- We prepend line numbers (e.g., `123|`); they are not part of the file.)
- "},
- PromptFormat::OldTextNewText => indoc! {"
- ## Code Excerpts
-
- Here is some excerpts of code that you should take into account to predict the next edit.
-
- The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
-
- In addition other excerpts are included to better understand what the edit will be, including the declaration
- or references of symbols around the cursor, or other similar code snippets that may need to be updated
- following patterns that appear in the edit history.
-
- Consider each of them carefully in relation to the edit history, and that the user may not have navigated
- to the next place they want to edit yet.
-
- Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
- "},
- PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
- indoc! {"
- ## Code Excerpts
-
- The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history have been applied.
- "}
- }
- };
-
- prompt.push_str(excerpts_preamble);
- prompt.push('\n');
-
- let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
- for related_file in &request.related_files {
- if request.prompt_format == PromptFormat::Minimal {
- write_codeblock_with_filename(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- } else {
- write_codeblock(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- }
- }
-
- match request.prompt_format {
- PromptFormat::OldTextNewText => {
- prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
- }
- PromptFormat::Minimal => {
- prompt.push_str(MINIMAL_PROMPT_REMINDER);
- }
- _ => {}
- }
-
- Ok(prompt)
-}
-
-pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
- match prompt_format {
- PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
- _ => GenerationParams::default(),
- }
-}
-
-pub fn write_codeblock<'a>(
- path: &Path,
- excerpts: impl IntoIterator- ,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &'a mut String,
-) {
- writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
-
- write_excerpts(
- excerpts,
- sorted_insertions,
- file_line_count,
- include_line_numbers,
- output,
- );
- write!(output, "`````\n\n").unwrap();
-}
-
-fn write_codeblock_with_filename<'a>(
- path: &Path,
- excerpts: impl IntoIterator
- ,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &'a mut String,
-) {
- writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
-
- write_excerpts(
- excerpts,
- sorted_insertions,
- file_line_count,
- include_line_numbers,
- output,
- );
- write!(output, "`````\n\n").unwrap();
-}
-
-pub fn write_excerpts<'a>(
- excerpts: impl IntoIterator
- ,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &mut String,
-) {
- let mut current_row = Line(0);
- let mut sorted_insertions = sorted_insertions.iter().peekable();
-
- for excerpt in excerpts {
- if excerpt.start_line > current_row {
- writeln!(output, "…").unwrap();
- }
- if excerpt.text.is_empty() {
- return;
- }
-
- current_row = excerpt.start_line;
-
- for mut line in excerpt.text.lines() {
- if include_line_numbers {
- write!(output, "{}|", current_row.0 + 1).unwrap();
- }
-
- while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
- match current_row.cmp(&insertion_location.line) {
- cmp::Ordering::Equal => {
- let (prefix, suffix) = line.split_at(insertion_location.column as usize);
- output.push_str(prefix);
- output.push_str(insertion_marker);
- line = suffix;
- sorted_insertions.next();
- }
- cmp::Ordering::Less => break,
- cmp::Ordering::Greater => {
- sorted_insertions.next();
- break;
- }
- }
- }
- output.push_str(line);
- output.push('\n');
- current_row.0 += 1;
- }
- }
-
- if current_row < file_line_count {
- writeln!(output, "…").unwrap();
- }
-}
-
-pub fn push_events(output: &mut String, events: &[Arc]) {
- if events.is_empty() {
- return;
- };
-
- writeln!(output, "`````diff").unwrap();
- for event in events {
- writeln!(output, "{}", event).unwrap();
- }
- writeln!(output, "`````\n").unwrap();
-}
-
-struct PromptData {
- events: Vec>,
- cursor_point: Point,
- cursor_path: Arc, // TODO: make a common struct with cursor_point
- included_files: Vec,
-}
-
-#[derive(Default)]
-pub struct GenerationParams {
- pub temperature: Option,
- pub top_p: Option,
- pub stop: Option>,
-}
-
-trait PromptFormatter {
- fn render(&self, data: &PromptData) -> String;
-
- fn generation_params() -> GenerationParams {
- return GenerationParams::default();
- }
-}
-
-struct MinimalQwenPrompt;
-
-impl PromptFormatter for MinimalQwenPrompt {
- fn render(&self, data: &PromptData) -> String {
- let edit_history = self.fmt_edit_history(data);
- let context = self.fmt_context(data);
-
- format!(
- "{instructions}\n\n{edit_history}\n\n{context}",
- instructions = MinimalQwenPrompt::INSTRUCTIONS,
- edit_history = edit_history,
- context = context
- )
- }
-}
-
-impl MinimalQwenPrompt {
- const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
-
- fn fmt_edit_history(&self, data: &PromptData) -> String {
- if data.events.is_empty() {
- "(No edit history)\n\n".to_string()
- } else {
- let mut events_str = String::new();
- push_events(&mut events_str, &data.events);
- format!(
- "The following are the latest edits made by the user, from earlier to later.\n\n{}",
- events_str
- )
- }
- }
-
- fn fmt_context(&self, data: &PromptData) -> String {
- let mut context = String::new();
- let include_line_numbers = true;
-
- for related_file in &data.included_files {
- writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
-
- if related_file.path == data.cursor_path {
- write!(context, "<|fim_prefix|>").unwrap();
- write_excerpts(
- &related_file.excerpts,
- &[(data.cursor_point, "<|fim_suffix|>")],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- writeln!(context, "<|fim_middle|>").unwrap();
- } else {
- write_excerpts(
- &related_file.excerpts,
- &[],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- }
- }
- context
- }
-}
-
-struct SeedCoder1120Prompt;
-
-impl PromptFormatter for SeedCoder1120Prompt {
- fn render(&self, data: &PromptData) -> String {
- let edit_history = self.fmt_edit_history(data);
- let context = self.fmt_context(data);
-
- format!(
- "# Edit History:\n{edit_history}\n\n{context}",
- edit_history = edit_history,
- context = context
- )
- }
-
- fn generation_params() -> GenerationParams {
- GenerationParams {
- temperature: Some(0.2),
- top_p: Some(0.9),
- stop: Some(vec!["<[end_of_sentence]>".into()]),
- }
- }
-}
-
-impl SeedCoder1120Prompt {
- fn fmt_edit_history(&self, data: &PromptData) -> String {
- if data.events.is_empty() {
- "(No edit history)\n\n".to_string()
- } else {
- let mut events_str = String::new();
- push_events(&mut events_str, &data.events);
- events_str
- }
- }
-
- fn fmt_context(&self, data: &PromptData) -> String {
- let mut context = String::new();
- let include_line_numbers = true;
-
- for related_file in &data.included_files {
- writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
-
- if related_file.path == data.cursor_path {
- let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
- context.push_str(&fim_prompt);
- } else {
- write_excerpts(
- &related_file.excerpts,
- &[],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- }
- }
- context
- }
-
- fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
- let mut buf = String::new();
- const FIM_SUFFIX: &str = "<[fim-suffix]>";
- const FIM_PREFIX: &str = "<[fim-prefix]>";
- const FIM_MIDDLE: &str = "<[fim-middle]>";
- write!(buf, "{}", FIM_PREFIX).unwrap();
- write_excerpts(
- &file.excerpts,
- &[(cursor_point, FIM_SUFFIX)],
- file.max_row,
- true,
- &mut buf,
- );
-
- // Swap prefix and suffix parts
- let index = buf.find(FIM_SUFFIX).unwrap();
- let prefix = &buf[..index];
- let suffix = &buf[index..];
-
- format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
- }
-}
diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml
index 6e62cfa6f038671d595c5671de147cdc2125064d..c9237232e5e0bb6167fbeee8732d46ee584b080b 100644
--- a/crates/edit_prediction/Cargo.toml
+++ b/crates/edit_prediction/Cargo.toml
@@ -21,7 +21,6 @@ arrayvec.workspace = true
brotli.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
collections.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
@@ -50,8 +49,6 @@ semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
-smol.workspace = true
-strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
@@ -62,6 +59,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
+zeta_prompt.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }
diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs
index 141fff3063b83d7e0003fddd6b4eba2d213d5fd5..b0d4a5f4d69c357fb0a153bee267a64dc0c465dd 100644
--- a/crates/edit_prediction/src/edit_prediction.rs
+++ b/crates/edit_prediction/src/edit_prediction.rs
@@ -1,14 +1,13 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
ZED_VERSION_HEADER_NAME,
};
-use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
AsyncReadExt as _, FutureExt as _, StreamExt as _,
- channel::{
- mpsc::{self, UnboundedReceiver},
- oneshot,
- },
+ channel::mpsc::{self, UnboundedReceiver},
select_biased,
};
use gpui::BackgroundExecutor;
@@ -58,8 +54,10 @@ mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
+
+#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
pub mod udiff;
-mod xml_edits;
+
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
-pub use crate::prediction::EditPredictionInputs;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use telemetry_events::EditPredictionRating;
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
min_bytes: 128,
target_before_cursor_over_total_bytes: 0.5,
},
- max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
prompt_format: PromptFormat::DEFAULT,
};
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
use_context: bool,
options: ZetaOptions,
update_required: bool,
- debug_tx: Option>,
#[cfg(feature = "eval-support")]
eval_cache: Option>,
edit_prediction_model: EditPredictionModel,
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
Mercury,
}
+pub struct EditPredictionModelInput {
+ project: Entity,
+ buffer: Entity,
+ snapshot: BufferSnapshot,
+ position: Anchor,
+ events: Vec>,
+ related_files: Arc<[RelatedFile]>,
+ recent_paths: VecDeque,
+ trigger: PredictEditsRequestTrigger,
+ diagnostic_search_range: Range,
+ debug_tx: Option>,
+}
+
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
pub context: EditPredictionExcerptOptions,
- pub max_prompt_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
}
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
pub enum DebugEvent {
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
- EditPredictionRequested(EditPredictionRequestedDebugEvent),
+ EditPredictionStarted(EditPredictionStartedDebugEvent),
+ EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
#[derive(Debug)]
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
}
#[derive(Debug)]
-pub struct EditPredictionRequestedDebugEvent {
- pub inputs: EditPredictionInputs,
- pub retrieval_time: Duration,
+pub struct EditPredictionStartedDebugEvent {
pub buffer: WeakEntity,
pub position: Anchor,
- pub local_prompt: Result,
- pub response_rx: oneshot::Receiver<(Result, Duration)>,
+ pub prompt: Option,
+}
+
+#[derive(Debug)]
+pub struct EditPredictionFinishedDebugEvent {
+ pub buffer: WeakEntity,
+ pub position: Anchor,
+ pub model_output: Option,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ProjectState {
- events: VecDeque>,
+ events: VecDeque>,
last_event: Option,
recent_paths: VecDeque,
registered_buffers: HashMap,
current_prediction: Option,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec,
- context_updates_tx: smol::channel::Sender<()>,
- context_updates_rx: smol::channel::Receiver<()>,
+ debug_tx: Option>,
last_prediction_refresh: Option<(EntityId, Instant)>,
cancelled_predictions: HashSet,
context: Entity,
@@ -241,7 +252,7 @@ struct ProjectState {
}
impl ProjectState {
- pub fn events(&self, cx: &App) -> Vec> {
+ pub fn events(&self, cx: &App) -> Vec> {
self.events
.iter()
.cloned()
@@ -376,7 +387,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap>,
cx: &App,
- ) -> Option> {
+ ) -> Option> {
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
@@ -396,7 +407,7 @@ impl LastEvent {
if path == old_path && diff.is_empty() {
None
} else {
- Some(Arc::new(predict_edits_v3::Event::BufferChange {
+ Some(Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -481,7 +492,6 @@ impl EditPredictionStore {
},
),
update_required: false,
- debug_tx: None,
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
@@ -536,12 +546,6 @@ impl EditPredictionStore {
self.eval_cache = Some(cache);
}
- pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver {
- let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
- self.debug_tx = Some(debug_watch_tx);
- debug_watch_rx
- }
-
pub fn options(&self) -> &ZetaOptions {
&self.options
}
@@ -560,15 +564,35 @@ impl EditPredictionStore {
}
}
+ pub fn edit_history_for_project(
+ &self,
+ project: &Entity,
+ ) -> Vec> {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project_state| project_state.events.iter().cloned().collect())
+ .unwrap_or_default()
+ }
+
pub fn context_for_project<'a>(
&'a self,
project: &Entity,
cx: &'a App,
- ) -> &'a [RelatedFile] {
+ ) -> Arc<[RelatedFile]> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files())
- .unwrap_or(&[])
+ .unwrap_or_else(|| vec![].into())
+ }
+
+ pub fn context_for_project_with_buffers<'a>(
+ &'a self,
+ project: &Entity,
+ cx: &'a App,
+ ) -> Option)>> {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project| project.context.read(cx).related_files_with_buffers())
}
pub fn usage(&self, cx: &App) -> Option {
@@ -599,85 +623,21 @@ impl EditPredictionStore {
cx: &mut Context,
) -> &mut ProjectState {
let entity_id = project.entity_id();
- let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
self.projects
.entry(entity_id)
.or_insert_with(|| ProjectState {
context: {
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
- cx.subscribe(
- &related_excerpt_store,
- move |this, _, event, _| match event {
- RelatedExcerptStoreEvent::StartedRefresh => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(DebugEvent::ContextRetrievalStarted(
- ContextRetrievalStartedDebugEvent {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- search_prompt: String::new(),
- },
- ))
- .ok();
- }
- }
- RelatedExcerptStoreEvent::FinishedRefresh {
- cache_hit_count,
- cache_miss_count,
- mean_definition_latency,
- max_definition_latency,
- } => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(DebugEvent::ContextRetrievalFinished(
- ContextRetrievalFinishedDebugEvent {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- metadata: vec![
- (
- "Cache Hits",
- format!(
- "{}/{}",
- cache_hit_count,
- cache_hit_count + cache_miss_count
- )
- .into(),
- ),
- (
- "Max LSP Time",
- format!(
- "{} ms",
- max_definition_latency.as_millis()
- )
- .into(),
- ),
- (
- "Mean LSP Time",
- format!(
- "{} ms",
- mean_definition_latency.as_millis()
- )
- .into(),
- ),
- ],
- },
- ))
- .ok();
- }
- if let Some(project_state) = this.projects.get(&entity_id) {
- project_state.context_updates_tx.send_blocking(()).ok();
- }
- }
- },
- )
+ cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
+ this.handle_excerpt_store_event(entity_id, event);
+ })
.detach();
related_excerpt_store
},
events: VecDeque::new(),
last_event: None,
recent_paths: VecDeque::new(),
- context_updates_rx,
- context_updates_tx,
+ debug_tx: None,
registered_buffers: HashMap::default(),
current_prediction: None,
cancelled_predictions: HashSet::default(),
@@ -689,12 +649,79 @@ impl EditPredictionStore {
})
}
- pub fn project_context_updates(
- &self,
+ pub fn remove_project(&mut self, project: &Entity) {
+ self.projects.remove(&project.entity_id());
+ }
+
+ fn handle_excerpt_store_event(
+ &mut self,
+ project_entity_id: EntityId,
+ event: &RelatedExcerptStoreEvent,
+ ) {
+ if let Some(project_state) = self.projects.get(&project_entity_id) {
+ if let Some(debug_tx) = project_state.debug_tx.clone() {
+ match event {
+ RelatedExcerptStoreEvent::StartedRefresh => {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalStarted(
+ ContextRetrievalStartedDebugEvent {
+ project_entity_id: project_entity_id,
+ timestamp: Instant::now(),
+ search_prompt: String::new(),
+ },
+ ))
+ .ok();
+ }
+ RelatedExcerptStoreEvent::FinishedRefresh {
+ cache_hit_count,
+ cache_miss_count,
+ mean_definition_latency,
+ max_definition_latency,
+ } => {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalFinished(
+ ContextRetrievalFinishedDebugEvent {
+ project_entity_id: project_entity_id,
+ timestamp: Instant::now(),
+ metadata: vec![
+ (
+ "Cache Hits",
+ format!(
+ "{}/{}",
+ cache_hit_count,
+ cache_hit_count + cache_miss_count
+ )
+ .into(),
+ ),
+ (
+ "Max LSP Time",
+ format!("{} ms", max_definition_latency.as_millis())
+ .into(),
+ ),
+ (
+ "Mean LSP Time",
+ format!("{} ms", mean_definition_latency.as_millis())
+ .into(),
+ ),
+ ],
+ },
+ ))
+ .ok();
+ }
+ }
+ }
+ }
+ }
+
+ pub fn debug_info(
+ &mut self,
project: &Entity,
- ) -> Option> {
- let project_state = self.projects.get(&project.entity_id())?;
- Some(project_state.context_updates_rx.clone())
+ cx: &mut Context,
+ ) -> mpsc::UnboundedReceiver {
+ let project_state = self.get_or_init_project(project, cx);
+ let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
+ project_state.debug_tx = Some(debug_watch_tx);
+ debug_watch_rx
}
fn handle_project_event(
@@ -1348,6 +1375,7 @@ impl EditPredictionStore {
let project_state = self.projects.get(&project.entity_id()).unwrap();
let events = project_state.events(cx);
let has_events = !events.is_empty();
+ let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
@@ -1357,55 +1385,29 @@ impl EditPredictionStore {
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let related_files = if self.use_context {
- self.context_for_project(&project, cx).to_vec()
+ self.context_for_project(&project, cx)
} else {
- Vec::new()
+ Vec::new().into()
+ };
+
+ let inputs = EditPredictionModelInput {
+ project: project.clone(),
+ buffer: active_buffer.clone(),
+ snapshot: snapshot.clone(),
+ position,
+ events,
+ related_files,
+ recent_paths: project_state.recent_paths.clone(),
+ trigger,
+ diagnostic_search_range: diagnostic_search_range.clone(),
+ debug_tx,
};
let task = match self.edit_prediction_model {
- EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- trigger,
- cx,
- ),
- EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- related_files,
- trigger,
- cx,
- ),
- EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &project_state.recent_paths,
- related_files,
- diagnostic_search_range.clone(),
- cx,
- ),
- EditPredictionModel::Mercury => self.mercury.request_prediction(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &project_state.recent_paths,
- related_files,
- diagnostic_search_range.clone(),
- cx,
- ),
+ EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
+ EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
+ EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
+ EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};
cx.spawn(async move |this, cx| {
@@ -1706,6 +1708,20 @@ impl EditPredictionStore {
}
}
+ #[cfg(feature = "eval-support")]
+ pub fn set_context_for_buffer(
+ &mut self,
+ project: &Entity<Project>,
+ related_files: Vec<RelatedFile>,
+ cx: &mut Context<Self>,
+ ) {
+ self.get_or_init_project(project, cx)
+ .context
+ .update(cx, |store, _| {
+ store.set_related_files(related_files);
+ });
+ }
+
fn is_file_open_source(
&self,
project: &Entity,
@@ -1729,14 +1745,14 @@ impl EditPredictionStore {
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
}
- fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
+ fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
if !self.data_collection_choice.is_enabled() {
return false;
}
events.iter().all(|event| {
matches!(
event.as_ref(),
- Event::BufferChange {
+ zeta_prompt::Event::BufferChange {
in_open_source_repo: true,
..
}
diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs
index 0b7e289bb32b5a10c32a4bd34f118d7cb6c7d43c..f6465b14cbd1b3357349071bc5eda399253b5328 100644
--- a/crates/edit_prediction/src/edit_prediction_tests.rs
+++ b/crates/edit_prediction/src/edit_prediction_tests.rs
@@ -1,5 +1,5 @@
use super::*;
-use crate::zeta1::MAX_EVENT_TOKENS;
+use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -7,7 +7,6 @@ use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
};
-use edit_prediction_context::Line;
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -28,6 +27,7 @@ use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
use util::{path, rel_path::rel_path};
use uuid::Uuid;
+use zeta_prompt::ZetaPromptInput;
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
@@ -65,18 +65,21 @@ async fn test_current_state(cx: &mut TestAppContext) {
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
- .send(model_response(indoc! {r"
- --- a/root/1.txt
- +++ b/root/1.txt
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
+ .send(model_response(
+ request,
+ indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "},
+ ))
.unwrap();
cx.run_until_parked();
@@ -120,16 +123,20 @@ async fn test_current_state(cx: &mut TestAppContext) {
});
});
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/2.txt
- +++ b/root/2.txt
- Hola!
- -Como
- +Como estas?
- Adios
- "#}))
+ .send(model_response(
+ request,
+ indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ @@ ... @@
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#},
+ ))
.unwrap();
cx.run_until_parked();
@@ -186,7 +193,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
// TODO Put back when we have a structured request again
// assert_eq!(
@@ -202,15 +209,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
// );
respond_tx
- .send(model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
+ .send(model_response(
+ request,
+ indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "},
+ ))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -276,15 +286,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
);
respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "#}))
+ .send(model_response(
+ request,
+ indoc! {r#"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "#},
+ ))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -324,18 +337,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- const NO_OP_DIFF: &str = indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How
- Bye
- "};
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let response = model_response(NO_OP_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let response = model_response(request, "");
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -389,13 +392,13 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text("Hello!\nHow are you?\nBye", cx);
});
- let response = model_response(SIMPLE_DIFF);
+ let response = model_response(request, SIMPLE_DIFF);
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -459,8 +462,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -482,8 +485,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let second_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -541,8 +544,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -564,17 +567,20 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
// worse than current prediction
- let second_response = model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are
- Bye
- "});
+ let second_response = model_response(
+ request,
+ indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are
+ Bye
+ "},
+ );
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -633,19 +639,19 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_second) = requests.predict.next().await.unwrap();
+ let (request, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle
cx.run_until_parked();
// second responds first
- let second_response = model_response(SIMPLE_DIFF);
+ let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_second.send(second_response).unwrap();
@@ -663,7 +669,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});
- let first_response = model_response(SIMPLE_DIFF);
+ let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -724,13 +730,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_second) = requests.predict.next().await.unwrap();
+ let (request2, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle, so requests are sent
cx.run_until_parked();
@@ -754,9 +760,9 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
// wait for throttle
cx.run_until_parked();
- let (_, respond_third) = requests.predict.next().await.unwrap();
+ let (request3, respond_third) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -774,7 +780,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let cancelled_response = model_response(SIMPLE_DIFF);
+ let cancelled_response = model_response(request2, SIMPLE_DIFF);
let cancelled_id = cancelled_response.id.clone();
respond_second.send(cancelled_response).unwrap();
@@ -792,7 +798,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let third_response = model_response(SIMPLE_DIFF);
+ let third_response = model_response(request3, SIMPLE_DIFF);
let third_response_id = third_response.id.clone();
respond_third.send(third_response).unwrap();
@@ -1036,7 +1042,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// );
// }
-fn model_response(text: &str) -> open_ai::Response {
+// Generate a model response that would apply the given diff to the active file.
+fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
+ let prompt = match &request.messages[0] {
+ open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(content),
+ } => content,
+ _ => panic!("unexpected request {request:?}"),
+ };
+
+ let open = "\n";
+ let close = "";
+ let cursor = "<|user_cursor|>";
+
+ let start_ix = open.len() + prompt.find(open).unwrap();
+ let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
+ let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
+ let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
+
open_ai::Response {
id: Uuid::new_v4().to_string(),
object: "response".into(),
@@ -1045,7 +1068,7 @@ fn model_response(text: &str) -> open_ai::Response {
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(text.to_string())),
+ content: Some(open_ai::MessageContent::Plain(new_excerpt)),
tool_calls: vec![],
},
finish_reason: None,
@@ -1160,20 +1183,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
- let completion = EditPrediction {
+ let prediction = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
- inputs: EditPredictionInputs {
+ inputs: ZetaPromptInput {
events: Default::default(),
- included_files: Default::default(),
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: Line(0),
- column: 0,
- },
+ related_files: Default::default(),
cursor_path: Path::new("").into(),
+ cursor_excerpt: "".into(),
+ editable_range_in_excerpt: 0..0,
+ cursor_offset_in_excerpt: 0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1182,7 +1204,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
cx.update(|cx| {
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1192,7 +1214,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1202,7 +1224,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1212,7 +1234,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1222,7 +1244,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1232,7 +1254,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1242,7 +1264,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1252,7 +1274,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1260,7 +1282,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
- assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
+ assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}
diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs
index 40c0fdfac021f937df5172fd423d3b6bfc5f8146..f3a3afc53fc5e175fdbda2dc6b5867da6fd38feb 100644
--- a/crates/edit_prediction/src/mercury.rs
+++ b/crates/edit_prediction/src/mercury.rs
@@ -1,20 +1,17 @@
use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
- App, AppContext as _, Entity, Task,
+ App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
-use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
-use project::{Project, ProjectPath};
-use std::{
- collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
-};
+use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
+use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
+use zeta_prompt::ZetaPromptInput;
use crate::{
- EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
+ DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+ EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
@@ -38,16 +35,17 @@ impl Mercury {
store_api_token_in_keychain(api_token, cx)
}
- pub fn request_prediction(
+ pub(crate) fn request_prediction(
&self,
- _project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- _recent_paths: &VecDeque<ProjectPath>,
- related_files: Vec<RelatedFile>,
- _diagnostic_search_range: Range<Point>,
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ events,
+ related_files,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
+ let active_buffer = buffer.clone();
let result = cx.background_spawn(async move {
let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
MAX_REWRITE_TOKENS,
);
- let offset_range = editable_range.to_offset(&snapshot);
- let prompt = build_prompt(
- &events,
- &related_files,
- &snapshot,
- full_path.as_ref(),
- cursor_point,
- editable_range,
- context_range.clone(),
- );
-
- let inputs = EditPredictionInputs {
- events: events,
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(
- context_range.start.row,
- ),
- text: snapshot
- .text_for_range(context_range.clone())
- .collect::<String>()
- .into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ let context_offset_range = context_range.to_offset(&snapshot);
+
+ let editable_offset_range = editable_range.to_offset(&snapshot);
+
+ let inputs = zeta_prompt::ZetaPromptInput {
+ events,
+ related_files,
+ cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
+ - context_range.start.to_offset(&snapshot),
cursor_path: full_path.clone(),
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt: (editable_offset_range.start
+ - context_offset_range.start)
+ ..(editable_offset_range.end - context_offset_range.start),
};
+ let prompt = build_prompt(&inputs);
+
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: active_buffer.downgrade(),
+ prompt: Some(prompt.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: active_buffer.downgrade(),
+ model_output: Some(response_str.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
@@ -168,15 +179,16 @@ impl Mercury {
if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
- .text_for_range(offset_range.clone())
+ .text_for_range(editable_offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
- snapshot.anchor_after(offset_range.start + range.start)
- ..snapshot.anchor_before(offset_range.start + range.end),
+ snapshot.anchor_after(editable_offset_range.start + range.start)
+ ..snapshot
+ .anchor_before(editable_offset_range.start + range.end),
text,
)
}),
@@ -186,8 +198,6 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
- let buffer = active_buffer.clone();
-
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
}
}
-fn build_prompt(
- events: &[Arc],
- related_files: &[RelatedFile],
- cursor_buffer: &BufferSnapshot,
- cursor_buffer_path: &Path,
- cursor_point: Point,
- editable_range: Range,
- context_range: Range,
-) -> String {
+fn build_prompt(inputs: &ZetaPromptInput) -> String {
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
&mut prompt,
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|prompt| {
- for related_file in related_files {
+ for related_file in inputs.related_files.iter() {
for related_excerpt in &related_file.excerpts {
push_delimited(
prompt,
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|prompt| {
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
- prompt.push_str(related_file.path.path.as_unix_str());
+ prompt.push_str(related_file.path.to_string_lossy().as_ref());
prompt.push('\n');
prompt.push_str(&related_excerpt.text.to_string());
},
@@ -259,21 +261,22 @@ fn build_prompt(
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|prompt| {
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
- prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
+ prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
prompt.push('\n');
- let prefix_range = context_range.start..editable_range.start;
- let suffix_range = editable_range.end..context_range.end;
-
- prompt.extend(cursor_buffer.text_for_range(prefix_range));
+ prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
- let range_before_cursor = editable_range.start..cursor_point;
- let range_after_cursor = cursor_point..editable_range.end;
- prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
+ prompt.push_str(
+ &inputs.cursor_excerpt
+ [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
+ );
prompt.push_str(CURSOR_TAG);
- prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
+ prompt.push_str(
+ &inputs.cursor_excerpt
+ [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
+ );
});
- prompt.extend(cursor_buffer.text_for_range(suffix_range));
+ prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
},
);
@@ -281,8 +284,8 @@ fn build_prompt(
&mut prompt,
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|prompt| {
- for event in events {
- writeln!(prompt, "{event}").unwrap();
+ for event in inputs.events.iter() {
+ zeta_prompt::write_event(prompt, &event);
}
},
);
diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs
index 8aa2a8218568a99404cc9aceff36b84127700152..c63640ccd0e1815b32f736e8a0fee8d75d124df1 100644
--- a/crates/edit_prediction/src/prediction.rs
+++ b/crates/edit_prediction/src/prediction.rs
@@ -1,6 +1,5 @@
use std::{
ops::Range,
- path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
use edit_prediction_types::interpolate_edits;
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
-use serde::Serialize;
+use zeta_prompt::ZetaPromptInput;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -40,7 +39,7 @@ impl EditPredictionResult {
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
- inputs: EditPredictionInputs,
+ inputs: ZetaPromptInput,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -94,15 +93,7 @@ pub struct EditPrediction {
pub buffer: Entity,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
- pub inputs: EditPredictionInputs,
-}
-
-#[derive(Debug, Clone, Serialize)]
-pub struct EditPredictionInputs {
- pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
- pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
- pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
- pub cursor_path: Arc<Path>,
+ pub inputs: zeta_prompt::ZetaPromptInput,
}
impl EditPrediction {
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
#[cfg(test)]
mod tests {
+ use std::path::Path;
+
use super::*;
use gpui::{App, Entity, TestAppContext, prelude::*};
use language::{Buffer, ToOffset as _};
+ use zeta_prompt::ZetaPromptInput;
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -154,14 +148,13 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
- inputs: EditPredictionInputs {
+ inputs: ZetaPromptInput {
events: vec![],
- included_files: vec![],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: cloud_llm_client::predict_edits_v3::Line(0),
- column: 0,
- },
+ related_files: vec![].into(),
cursor_path: Path::new("path.txt").into(),
+ cursor_offset_in_excerpt: 0,
+ cursor_excerpt: "".into(),
+ editable_range_in_excerpt: 0..0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs
index 4bb014c640cb489db29c800835a58febf91a7270..f65749ceadf6e05fc3b56838c03234b2f83dc51e 100644
--- a/crates/edit_prediction/src/sweep_ai.rs
+++ b/crates/edit_prediction/src/sweep_ai.rs
@@ -1,26 +1,21 @@
use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
- App, AppContext as _, Entity, Task,
+ App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
-use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
+use language::{Point, ToOffset as _};
use lsp::DiagnosticSeverity;
-use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
- collections::VecDeque,
fmt::{self, Write as _},
- ops::Range,
path::Path,
sync::Arc,
time::Instant,
};
-use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
+use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
@@ -44,40 +39,34 @@ impl SweepAi {
pub fn request_prediction_with_sweep(
&self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- recent_paths: &VecDeque<ProjectPath>,
- related_files: Vec<RelatedFile>,
- diagnostic_search_range: Range<Point>,
+ inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
- let full_path: Arc<Path> = snapshot
+ let full_path: Arc<Path> = inputs
+ .snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
- let project_file = project::File::from_dyn(snapshot.file());
+ let project_file = project::File::from_dyn(inputs.snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
- let offset = position.to_offset(&snapshot);
+ let offset = inputs.position.to_offset(&inputs.snapshot);
- let recent_buffers = recent_paths.iter().cloned();
+ let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
- let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
- if active_buffer == &buffer {
+ let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
+ if inputs.buffer == buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();
- let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
- let text = snapshot.text();
+ let text = inputs.snapshot.text();
let mut recent_changes = String::new();
- for event in &events {
+ for event in &inputs.events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}
@@ -122,20 +110,23 @@ impl SweepAi {
})
.collect::<Vec<_>>();
- let retrieval_chunks = related_files
+ let retrieval_chunks = inputs
+ .related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
- file_path: related_file.path.path.as_unix_str().to_string(),
- start_line: excerpt.point_range.start.row as usize,
- end_line: excerpt.point_range.end.row as usize,
+ file_path: related_file.path.to_string_lossy().to_string(),
+ start_line: excerpt.row_range.start as usize,
+ end_line: excerpt.row_range.end as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();
- let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
+ let diagnostic_entries = inputs
+ .snapshot
+ .diagnostics_in_range(inputs.diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@@ -195,21 +186,14 @@ impl SweepAi {
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
- let inputs = EditPredictionInputs {
- events,
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(0),
- text: request_body.file_contents.into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ let ep_inputs = zeta_prompt::ZetaPromptInput {
+ events: inputs.events,
+ related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
+ cursor_excerpt: request_body.file_contents.into(),
+ // we actually don't know
+ editable_range_in_excerpt: 0..inputs.snapshot.len(),
+ cursor_offset_in_excerpt: request_body.cursor_position,
};
let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
- let old_text = snapshot
+ let old_text = inputs
+ .snapshot
.text_for_range(response.start_index..response.end_index)
.collect::();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
- snapshot.anchor_after(response.start_index + range.start)
- ..snapshot.anchor_before(response.start_index + range.end),
+ inputs
+ .snapshot
+ .anchor_after(response.start_index + range.start)
+ ..inputs
+ .snapshot
+ .anchor_before(response.start_index + range.end),
text,
)
})
@@ -254,13 +243,13 @@ impl SweepAi {
anyhow::Ok((
response.autocomplete_id,
edits,
- snapshot,
+ inputs.snapshot,
response_received_at,
- inputs,
+ ep_inputs,
))
});
- let buffer = active_buffer.clone();
+ let buffer = inputs.buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
pub finish_reason: Option,
}
-fn write_event(
- event: &cloud_llm_client::predict_edits_v3::Event,
- f: &mut impl fmt::Write,
-) -> fmt::Result {
+fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
match event {
- cloud_llm_client::predict_edits_v3::Event::BufferChange {
+ zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs
index 5ae029c6c16c2c6b6d0c2451cc961e8399a64a8f..b9cf564c16d68a98baa1986333f2bfd767c6a24b 100644
--- a/crates/edit_prediction/src/udiff.rs
+++ b/crates/edit_prediction/src/udiff.rs
@@ -14,68 +14,18 @@ use anyhow::anyhow;
use collections::HashMap;
use gpui::AsyncApp;
use gpui::Entity;
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
+use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
use project::Project;
-pub async fn parse_diff<'a>(
- diff_str: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> {
- let mut diff = DiffParser::new(diff_str);
- let mut edited_buffer = None;
- let mut edits = Vec::new();
-
- while let Some(event) = diff.next()? {
- match event {
- DiffEvent::Hunk {
- path: file_path,
- hunk,
- } => {
- let (buffer, ranges) = match edited_buffer {
- None => {
- edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
- edited_buffer
- .as_ref()
- .context("Model tried to edit a file that wasn't included")?
- }
- Some(ref current) => current,
- };
-
- edits.extend(
- resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
- .with_context(|| format!("Diff:\n{diff_str}"))?,
- );
- }
- DiffEvent::FileEnd { renamed_to } => {
- let (buffer, _) = edited_buffer
- .take()
- .context("Got a FileEnd event before an Hunk event")?;
-
- if renamed_to.is_some() {
- anyhow::bail!("edit predictions cannot rename files");
- }
-
- if diff.next()?.is_some() {
- anyhow::bail!("Edited more than one file");
- }
-
- return Ok((buffer, edits));
- }
- }
- }
-
- Err(anyhow::anyhow!("No EOF"))
-}
-
-#[derive(Debug)]
-pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap, Entity>);
+#[derive(Clone, Debug)]
+pub struct OpenedBuffers(#[allow(unused)] HashMap>);
#[must_use]
-pub async fn apply_diff<'a>(
- diff_str: &'a str,
+pub async fn apply_diff(
+ diff_str: &str,
project: &Entity,
cx: &mut AsyncApp,
-) -> Result> {
+) -> Result {
let mut included_files = HashMap::default();
for line in diff_str.lines() {
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
})??
.await?;
- included_files.insert(path, buffer);
+ included_files.insert(path.to_string(), buffer);
}
}
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
let (buffer, ranges) = match current_file {
None => {
let buffer = included_files
- .get_mut(&file_path)
+ .get_mut(file_path.as_ref())
.expect("Opened all files in diff");
current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
Ok(OpenedBuffers(included_files))
}
+pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result {
+ let mut diff = DiffParser::new(diff_str);
+
+ let mut text = text.to_string();
+
+ while let Some(event) = diff.next()? {
+ match event {
+ DiffEvent::Hunk { hunk, .. } => {
+ let hunk_offset = text
+ .find(&hunk.context)
+                    .ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
+ for edit in hunk.edits.iter().rev() {
+ let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
+ text.replace_range(range, &edit.text);
+ }
+ }
+ DiffEvent::FileEnd { .. } => {}
+ }
+ }
+
+ Ok(text)
+}
+
struct PatchFile<'a> {
old_path: Cow<'a, str>,
new_path: Cow<'a, str>,
@@ -492,7 +465,6 @@ mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
- use language::Point;
use pretty_assertions::assert_eq;
use project::{FakeFs, Project};
use serde_json::json;
@@ -817,137 +789,6 @@ mod tests {
});
}
- #[gpui::test]
- async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one
- two
- three
- four
- five
- one
- two
- three
- four
- five
- "# };
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one
- two
- -three
- +3
- four
- five
- "#};
-
- let final_text = indoc! {r#"
- one
- two
- three
- four
- five
- one
- two
- 3
- four
- five
- "#};
-
- apply_diff(diff, &project, &mut cx.to_async())
- .await
- .expect_err("Non-unique edits should fail");
-
- let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
- ..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
-
- let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
- .await
- .unwrap();
-
- assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
- buffer.update(cx, |buffer, cx| {
- buffer.edit(edits, None, cx);
- assert_eq!(buffer.text(), final_text);
- });
- }
-
- #[gpui::test]
- async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one two three four
- five six seven eight
- nine ten eleven twelve
- "# };
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one two three four
- -five six seven eight
- +five SIX seven eight!
- nine ten eleven twelve
- "#};
-
- let (buffer, edits) = parse_diff(diff, |_path| {
- Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
- })
- .await
- .unwrap();
-
- let edits = edits
- .into_iter()
- .map(|(range, text)| (range.to_point(&buffer), text))
- .collect::>();
- assert_eq!(
- edits,
- &[
- (Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
- (Point::new(1, 20)..Point::new(1, 20), "!".into())
- ]
- );
- }
-
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
let fs = init_test(cx);
diff --git a/crates/edit_prediction/src/xml_edits.rs b/crates/edit_prediction/src/xml_edits.rs
deleted file mode 100644
index ee8dd47cb25ad3dcd2c3d7d172b62e724b41c22d..0000000000000000000000000000000000000000
--- a/crates/edit_prediction/src/xml_edits.rs
+++ /dev/null
@@ -1,637 +0,0 @@
-use anyhow::{Context as _, Result};
-use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
-use std::{cmp, ops::Range, path::Path, sync::Arc};
-
-const EDITS_TAG_NAME: &'static str = "edits";
-const OLD_TEXT_TAG_NAME: &'static str = "old_text";
-const NEW_TEXT_TAG_NAME: &'static str = "new_text";
-const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
-
-pub async fn parse_xml_edits<'a>(
- input: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> {
- parse_xml_edits_inner(input, get_buffer)
- .await
- .with_context(|| format!("Failed to parse XML edits:\n{input}"))
-}
-
-async fn parse_xml_edits_inner<'a>(
- input: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> {
- let xml_edits = extract_xml_replacements(input)?;
-
- let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
- .with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
-
- let mut all_edits = vec![];
- for (old_text, new_text) in xml_edits.replacements {
- let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
- let matched_old_text = buffer
- .text_for_range(match_range.clone())
- .collect::();
- let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
- all_edits.extend(
- edits_within_hunk
- .into_iter()
- .map(move |(inner_range, inner_text)| {
- (
- buffer.anchor_after(match_range.start + inner_range.start)
- ..buffer.anchor_before(match_range.start + inner_range.end),
- inner_text,
- )
- }),
- );
- }
-
- Ok((buffer, all_edits))
-}
-
-fn fuzzy_match_in_ranges(
- old_text: &str,
- buffer: &BufferSnapshot,
- context_ranges: &[Range],
-) -> Result> {
- let mut state = FuzzyMatcher::new(buffer, old_text);
- let mut best_match = None;
- let mut tie_match_range = None;
-
- for range in context_ranges {
- let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
- match (best_match_cost, state.match_range(range.to_offset(buffer))) {
- (Some(lowest_cost), Some((new_cost, new_range))) => {
- if new_cost == lowest_cost {
- tie_match_range = Some(new_range);
- } else if new_cost < lowest_cost {
- tie_match_range.take();
- best_match = Some((new_cost, new_range));
- }
- }
- (None, Some(new_match)) => {
- best_match = Some(new_match);
- }
- (None, None) | (Some(_), None) => {}
- };
- }
-
- if let Some((_, best_match_range)) = best_match {
- if let Some(tie_match_range) = tie_match_range {
- anyhow::bail!(
- "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
- best_match_range.clone(),
- buffer.text_for_range(best_match_range).collect::(),
- tie_match_range.clone(),
- buffer.text_for_range(tie_match_range).collect::()
- );
- }
- return Ok(best_match_range);
- }
-
- anyhow::bail!(
- "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
- old_text,
- context_ranges
- .iter()
- .map(|range| buffer.text_for_range(range.clone()).collect::())
- .collect::>()
- .join("```\n```")
- );
-}
-
-#[derive(Debug)]
-struct XmlEdits<'a> {
- file_path: &'a str,
- /// Vec of (old_text, new_text) pairs
- replacements: Vec<(&'a str, &'a str)>,
-}
-
-fn extract_xml_replacements(input: &str) -> Result> {
- let mut cursor = 0;
-
- let (edits_body_start, edits_attrs) =
- find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
-
- let file_path = edits_attrs
- .trim_start()
- .strip_prefix("path")
- .context("no path attribute on edits tag")?
- .trim_end()
- .strip_prefix('=')
- .context("no value for path attribute")?
- .trim()
- .trim_start_matches('"')
- .trim_end_matches('"');
-
- cursor = edits_body_start;
- let mut edits_list = Vec::new();
-
- while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
- let old_body_end = find_tag_close(input, &mut cursor)?;
- let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
-
- let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
- .context("no new_text tag following old_text")?;
- let new_body_end = find_tag_close(input, &mut cursor)?;
- let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
-
- edits_list.push((old_text, new_text));
- }
-
- Ok(XmlEdits {
- file_path,
- replacements: edits_list,
- })
-}
-
-/// Trims a single leading and trailing newline
-fn trim_surrounding_newlines(input: &str) -> &str {
- let start = input.strip_prefix('\n').unwrap_or(input);
- let end = start.strip_suffix('\n').unwrap_or(start);
- end
-}
-
-fn find_tag_open<'a>(
- input: &'a str,
- cursor: &mut usize,
- expected_tag: &str,
-) -> Result
-
- new content
-
- edits >
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "old content");
- assert_eq!(result.replacements[0].1, "new content");
- }
-
- #[test]
- fn test_extract_xml_edits_with_xml_like_content() {
- let input = indoc! {r#"
-
-
-
-
-
-
-
-
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "component.tsx");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "");
- assert_eq!(
- result.replacements[0].1,
- ""
- );
- }
-
- #[test]
- fn test_extract_xml_edits_with_conflicting_content() {
- let input = indoc! {r#"
-
-
-
-
-
-
-
-
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "component.tsx");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "");
- assert_eq!(result.replacements[0].1, "");
- }
-
- #[test]
- fn test_extract_xml_edits_multiple_pairs() {
- let input = indoc! {r#"
- Some reasoning before edits. Lots of thinking going on here
-
-
-
- first old
-
-
- first new
-
-
- second old
-
-
- second new
-
-
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 2);
- assert_eq!(result.replacements[0].0, "first old");
- assert_eq!(result.replacements[0].1, "first new");
- assert_eq!(result.replacements[1].0, "second old");
- assert_eq!(result.replacements[1].1, "second new");
- }
-
- #[test]
- fn test_extract_xml_edits_unexpected_eof() {
- let input = indoc! {r#"
-
-
- first old
-
- "#};
-
- extract_xml_replacements(input).expect_err("Unexpected end of file");
- }
-
- #[gpui::test]
- async fn test_parse_xml_edits(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one two three four
- five six seven eight
- nine ten eleven twelve
- thirteen fourteen fifteen
- sixteen seventeen eighteen
- "#};
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let edits = indoc! {r#"
-
-
- nine ten eleven twelve
-
-
- nine TEN eleven twelve!
-
-
- "#};
-
- let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
- let (buffer, edits) = parse_xml_edits(edits, |_path| {
- Some((&buffer_snapshot, included_ranges.as_slice()))
- })
- .await
- .unwrap();
-
- let edits = edits
- .into_iter()
- .map(|(range, text)| (range.to_point(&buffer), text))
- .collect::>();
- assert_eq!(
- edits,
- &[
- (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
- (Point::new(2, 22)..Point::new(2, 22), "!".into())
- ]
- );
- }
-
- fn init_test(cx: &mut TestAppContext) -> Arc {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
-
- FakeFs::new(cx.background_executor.clone())
- }
-}
diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs
index ad630484d392d75849bd33a52a55e63ea77ca23f..ed531749cb39d10d71d18947990dd1972f23a986 100644
--- a/crates/edit_prediction/src/zeta1.rs
+++ b/crates/edit_prediction/src/zeta1.rs
@@ -1,22 +1,23 @@
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
- EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
+ DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+ EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
- prediction::{EditPredictionInputs, EditPredictionResult},
+ prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::{
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
- predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use language::{
- Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
+ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+use zeta_prompt::{Event, ZetaPromptInput};
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
store: &mut EditPredictionStore,
- project: &Entity,
- buffer: &Entity,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec>,
- trigger: PredictEditsRequestTrigger,
+ EditPredictionModelInput {
+ project,
+ buffer,
+ snapshot,
+ position,
+ events,
+ trigger,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut Context,
) -> Task>> {
- let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
- let can_collect_file = store.can_collect_file(project, file, cx);
+ let can_collect_file = store.can_collect_file(&project, file, cx);
let git_info = if can_collect_file {
- git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+ git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
};
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
)
.await;
- let inputs = EditPredictionInputs {
+ let context_start_offset = context_range.start.to_offset(&snapshot);
+ let editable_offset_range = editable_range.to_offset(&snapshot);
+
+ let inputs = ZetaPromptInput {
events: included_events.into(),
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
- text: snapshot
- .text_for_range(context_range)
- .collect::()
- .into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ related_files: vec![].into(),
cursor_path: full_path,
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::()
+ .into(),
+ editable_range_in_excerpt: (editable_range.start - context_start_offset)
+ ..(editable_offset_range.end - context_start_offset),
+ cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
};
- // let response = perform_predict_edits(PerformPredictEditsParams {
- // client,
- // llm_token,
- // app_version,
- // body,
- // })
- // .await;
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ prompt: Some(serde_json::to_string(&inputs).unwrap()),
+ position,
+ },
+ ))
+ .ok();
+ }
let (response, usage) = match response {
Ok(response) => response,
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
.ok();
}
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ model_output: Some(response.output_excerpt.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let edit_prediction = process_completion_response(
response,
buffer,
@@ -226,7 +242,7 @@ fn process_completion_response(
buffer: Entity,
snapshot: &BufferSnapshot,
editable_range: Range,
- inputs: EditPredictionInputs,
+ inputs: ZetaPromptInput,
buffer_snapshotted_at: Instant,
received_response_at: Instant,
cx: &AsyncApp,
diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs
index e542bc7e86e6e381766bbedac6a15f431e0693f1..034954f5760939fc31b3e5e1e8a09737c5b2e568 100644
--- a/crates/edit_prediction/src/zeta2.rs
+++ b/crates/edit_prediction/src/zeta2.rs
@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
- DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
- EditPredictionRequestedDebugEvent, EditPredictionStore,
+ DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
+ EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
-use anyhow::{Result, anyhow, bail};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
-use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
-use cloud_zeta2_prompt::CURSOR_MARKER;
-use edit_prediction_context::{EditPredictionExcerpt, Line};
-use edit_prediction_context::{RelatedExcerpt, RelatedFile};
-use futures::channel::oneshot;
-use gpui::{Entity, Task, prelude::*};
-use language::{Anchor, BufferSnapshot};
-use language::{Buffer, Point, ToOffset as _, ToPoint};
-use project::{Project, ProjectItem as _};
+use anyhow::{Result, anyhow};
+use cloud_llm_client::EditPredictionRejectReason;
+use gpui::{Task, prelude::*};
+use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;
-use std::{
- env,
- path::Path,
- sync::Arc,
- time::{Duration, Instant},
-};
+use std::{path::Path, sync::Arc, time::Instant};
+use zeta_prompt::CURSOR_MARKER;
+use zeta_prompt::format_zeta_prompt;
+
+const MAX_CONTEXT_TOKENS: usize = 150;
+const MAX_REWRITE_TOKENS: usize = 350;
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
- project: &Entity,
- active_buffer: &Entity,
- active_snapshot: BufferSnapshot,
- position: Anchor,
- events: Vec>,
- mut included_files: Vec,
- trigger: PredictEditsRequestTrigger,
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ related_files,
+ events,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut Context,
) -> Task>> {
- let options = store.options.clone();
let buffer_snapshotted_at = Instant::now();
- let Some((excerpt_path, active_project_path)) = active_snapshot
+ let Some(excerpt_path) = snapshot
.file()
.map(|file| -> Arc { file.full_path(cx).into() })
- .zip(active_buffer.read(cx).project_path(cx))
else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
};
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
- let debug_tx = store.debug_tx.clone();
-
- let file = active_buffer.read(cx).file();
-
- let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
-
- // TODO data collection
- let can_collect_data = file
- .as_ref()
- .map_or(false, |file| store.can_collect_file(project, file, cx));
#[cfg(feature = "eval-support")]
let eval_cache = store.eval_cache.clone();
let request_task = cx.background_spawn({
- let active_buffer = active_buffer.clone();
async move {
- let cursor_offset = position.to_offset(&active_snapshot);
- let cursor_point = cursor_offset.to_point(&active_snapshot);
-
- let before_retrieval = Instant::now();
-
- let excerpt_options = options.context;
-
- let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &active_snapshot,
- &excerpt_options,
- ) else {
- return Ok((None, None));
- };
-
- let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
- ..active_snapshot.anchor_before(excerpt.range.end);
- let related_excerpt = RelatedExcerpt {
- anchor_range: excerpt_anchor_range.clone(),
- point_range: Point::new(excerpt.line_range.start.0, 0)
- ..Point::new(excerpt.line_range.end.0, 0),
- text: active_snapshot.as_rope().slice(excerpt.range),
- };
-
- if let Some(buffer_ix) = included_files
- .iter()
- .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
- {
- let file = &mut included_files[buffer_ix];
- file.excerpts.push(related_excerpt);
- file.merge_excerpts();
- let last_ix = included_files.len() - 1;
- included_files.swap(buffer_ix, last_ix);
- } else {
- let active_file = RelatedFile {
- path: active_project_path,
- buffer: active_buffer.downgrade(),
- excerpts: vec![related_excerpt],
- max_row: active_snapshot.max_point().row,
- };
- included_files.push(active_file);
- }
-
- let included_files = included_files
- .iter()
- .map(|related_file| predict_edits_v3::RelatedFile {
- path: Arc::from(related_file.path.path.as_std_path()),
- max_row: Line(related_file.max_row),
- excerpts: related_file
- .excerpts
- .iter()
- .map(|excerpt| predict_edits_v3::Excerpt {
- start_line: Line(excerpt.point_range.start.row),
- text: excerpt.text.to_string().into(),
- })
- .collect(),
- })
- .collect::>();
-
- let cloud_request = predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: String::new(),
- excerpt_line_range: Line(0)..Line(0),
- excerpt_range: 0..0,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(cursor_point.row),
- column: cursor_point.column,
- },
- related_files: included_files,
+ let cursor_offset = position.to_offset(&snapshot);
+ let (editable_offset_range, prompt_input) = zeta2_prompt_input(
+ &snapshot,
+ related_files,
events,
- can_collect_data,
- debug_info: debug_tx.is_some(),
- prompt_max_bytes: Some(options.max_prompt_bytes),
- prompt_format: options.prompt_format,
- excerpt_parent: None,
- git_info: None,
- trigger,
- };
-
- let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
-
- let inputs = EditPredictionInputs {
- included_files: cloud_request.related_files,
- events: cloud_request.events,
- cursor_point: cloud_request.cursor_point,
- cursor_path: cloud_request.excerpt_path,
- };
-
- let retrieval_time = Instant::now() - before_retrieval;
+ excerpt_path,
+ cursor_offset,
+ );
- let debug_response_tx = if let Some(debug_tx) = &debug_tx {
- let (response_tx, response_rx) = oneshot::channel();
+ let prompt = format_zeta_prompt(&prompt_input);
+ if let Some(debug_tx) = &debug_tx {
debug_tx
- .unbounded_send(DebugEvent::EditPredictionRequested(
- EditPredictionRequestedDebugEvent {
- inputs: inputs.clone(),
- retrieval_time,
- buffer: active_buffer.downgrade(),
- local_prompt: match prompt_result.as_ref() {
- Ok(prompt) => Ok(prompt.clone()),
- Err(err) => Err(err.to_string()),
- },
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ prompt: Some(prompt.clone()),
position,
- response_rx,
},
))
.ok();
- Some(response_tx)
- } else {
- None
- };
-
- if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((Err("Request skipped".to_string()), Duration::ZERO))
- .ok();
- }
- anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
- let prompt = prompt_result?;
- let generation_params =
- cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
let request = open_ai::Request {
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
}],
stream: false,
max_completion_tokens: None,
- stop: generation_params.stop.unwrap_or_default(),
- temperature: generation_params.temperature.or(Some(0.7)),
+ stop: Default::default(),
+ temperature: Default::default(),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
log::trace!("Sending edit prediction request");
- let before_request = Instant::now();
let response = EditPredictionStore::send_raw_llm_request(
request,
client,
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
)
.await;
let received_response_at = Instant::now();
- let request_time = received_response_at - before_request;
log::trace!("Got edit prediction response");
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((
- response
- .as_ref()
- .map_err(|err| err.to_string())
- .map(|response| response.0.clone()),
- request_time,
- ))
- .ok();
- }
-
let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(mut output_text) = text_from_response(res) else {
return Ok((Some((request_id, None)), usage));
};
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ position,
+ model_output: Some(output_text.clone()),
+ },
+ ))
+ .ok();
+ }
+
if output_text.contains(CURSOR_MARKER) {
log::trace!("Stripping out {CURSOR_MARKER} from response");
output_text = output_text.replace(CURSOR_MARKER, "");
}
- let get_buffer_from_context = |path: &Path| {
- if Some(path) == active_file_full_path.as_deref() {
- Some((
- &active_snapshot,
- std::slice::from_ref(&excerpt_anchor_range),
- ))
- } else {
- None
- }
- };
-
- let (_, edits) = match options.prompt_format {
- PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
- if output_text.contains("--- a/\n+++ b/\nNo edits") {
- let edits = vec![];
- (&active_snapshot, edits)
- } else {
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- }
- PromptFormat::OldTextNewText => {
- crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
- }
- _ => {
- bail!("unsupported prompt format {}", options.prompt_format)
- }
- };
+ let old_text = snapshot
+ .text_for_range(editable_offset_range.clone())
+ .collect::();
+ let edits: Vec<_> = language::text_diff(&old_text, &output_text)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(editable_offset_range.start + range.start)
+ ..snapshot.anchor_before(editable_offset_range.start + range.end),
+ text,
+ )
+ })
+ .collect();
anyhow::Ok((
Some((
request_id,
Some((
- inputs,
- active_buffer,
- active_snapshot.clone(),
+ prompt_input,
+ buffer,
+ snapshot.clone(),
edits,
received_response_at,
)),
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
))
})
}
+
+pub fn zeta2_prompt_input(
+ snapshot: &language::BufferSnapshot,
+ related_files: Arc<[zeta_prompt::RelatedFile]>,
+ events: Vec>,
+ excerpt_path: Arc,
+ cursor_offset: usize,
+) -> (std::ops::Range, zeta_prompt::ZetaPromptInput) {
+ let cursor_point = cursor_offset.to_point(snapshot);
+
+ let (editable_range, context_range) =
+ crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ snapshot,
+ MAX_CONTEXT_TOKENS,
+ MAX_REWRITE_TOKENS,
+ );
+
+ let context_start_offset = context_range.start.to_offset(snapshot);
+ let editable_offset_range = editable_range.to_offset(snapshot);
+ let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
+ let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
+ ..(editable_offset_range.end - context_start_offset);
+
+ let prompt_input = zeta_prompt::ZetaPromptInput {
+ cursor_path: excerpt_path,
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::()
+ .into(),
+ editable_range_in_excerpt,
+ cursor_offset_in_excerpt,
+ events,
+ related_files,
+ };
+ (editable_offset_range, prompt_input)
+}
diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml
index 26a060994d75a2c194cc159c33d88fbc296dfa47..0e7fff8d70156c58147069f8da64035d6a80adc8 100644
--- a/crates/edit_prediction_cli/Cargo.toml
+++ b/crates/edit_prediction_cli/Cargo.toml
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
-name = "ep_cli"
+name = "ep"
path = "src/main.rs"
[dependencies]
@@ -20,10 +20,9 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
-cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
-edit_prediction_context.workspace = true
+dirs.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -51,12 +50,21 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
terminal_view.workspace = true
-toml.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["eval-support"] }
+wasmtime.workspace = true
+zeta_prompt.workspace = true
zlog.workspace = true
+# Wasmtime is included as a dependency in order to enable the same
+# features that are enabled in Zed.
+#
+# If we don't enable these features we get crashes when creating
+# a Tree-sitter WasmStore.
+[package.metadata.cargo-machete]
+ignored = ["wasmtime"]
+
[dev-dependencies]
indoc.workspace = true
gpui = { workspace = true, features = ["test-support"] }
diff --git a/crates/edit_prediction_cli/src/training/llm_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs
similarity index 89%
rename from crates/edit_prediction_cli/src/training/llm_client.rs
rename to crates/edit_prediction_cli/src/anthropic_client.rs
index ebecbe915d36a9a456296e818e559c654370f939..8afc4d1c03f8a37ae258cc2926daf85caebe3d8a 100644
--- a/crates/edit_prediction_cli/src/training/llm_client.rs
+++ b/crates/edit_prediction_cli/src/anthropic_client.rs
@@ -5,11 +5,13 @@ use anthropic::{
use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
+use reqwest_client::ReqwestClient;
use sqlez::bindable::Bind;
use sqlez::bindable::StaticColumnCount;
use sqlez_macros::sql;
use std::hash::Hash;
use std::hash::Hasher;
+use std::path::Path;
use std::sync::Arc;
pub struct PlainLlmClient {
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
}
impl PlainLlmClient {
- fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ fn new() -> Result<Self> {
+ let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
Ok(Self {
@@ -29,12 +32,12 @@ impl PlainLlmClient {
async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec,
) -> Result {
let request = AnthropicRequest {
- model,
+ model: model.to_string(),
max_tokens,
messages,
tools: Vec::new(),
@@ -105,11 +108,12 @@ struct SerializableMessage {
}
impl BatchingLlmClient {
- fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ fn new(cache_path: &Path) -> Result<Self> {
+ let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
- let connection = sqlez::connection::Connection::open_file(&cache_path);
+ let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
let mut statement = sqlez::statement::Statement::prepare(
&connection,
indoc! {"
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec,
) -> Result