Rework edit prediction CLI (#44562)

Max Brunsfeld , Oleksiy Syvokon , Agus Zubiaga , and Ben Kunkle created

This PR restructures the commands of the Edit Prediction CLI (now called
`ep`), to support some flows that are important for the training
process:
* generating zeta2 prompt and expected output, without running
predictions
* scoring outputs that are generated by a system other than the
production code (to evaluate the model during training)

To achieve this, we've restructured the CLI commands so that they all
take as input, and produce as output, a consistent, uniform data format:
a set of one or more `Example` structs, expressible either as the
original markdown format, or as a JSON lines. The `Example` struct
starts with the basic fields that are in human-readable eval format, but
contain a number of optional fields that are filled in by different
steps in the processing pipeline (`context`, `predict`, `format-prompt`,
and `score`).

### To do

* [x] Adjust the teacher model output parsing to use the full buffer
contents
* [x] Move udiff to cli
* [x] Align `format-prompt` with Zeta2's production code
* [x] Change score output to assume same provider
* [x] Move pretty reporting to `eval` command
* [x] Store cursor point in addition to cursor offset
* [x] Rename `edit_prediction_cli2` -> `edit_prediction_cli` (nuke the
old one)

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

Cargo.lock                                                          |  30 
Cargo.toml                                                          |   5 
crates/client/Cargo.toml                                            |   2 
crates/cloud_zeta2_prompt/Cargo.toml                                |  18 
crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs                 | 485 
crates/edit_prediction/Cargo.toml                                   |   4 
crates/edit_prediction/src/edit_prediction.rs                       | 320 
crates/edit_prediction/src/edit_prediction_tests.rs                 | 218 
crates/edit_prediction/src/mercury.rs                               | 155 
crates/edit_prediction/src/prediction.rs                            |  29 
crates/edit_prediction/src/sweep_ai.rs                              |  96 
crates/edit_prediction/src/udiff.rs                                 | 221 
crates/edit_prediction/src/xml_edits.rs                             | 637 
crates/edit_prediction/src/zeta1.rs                                 |  90 
crates/edit_prediction/src/zeta2.rs                                 | 303 
crates/edit_prediction_cli/Cargo.toml                               |  16 
crates/edit_prediction_cli/src/anthropic_client.rs                  |  51 
crates/edit_prediction_cli/src/evaluate.rs                          | 641 
crates/edit_prediction_cli/src/example.rs                           | 797 
crates/edit_prediction_cli/src/format_prompt.rs                     | 280 
crates/edit_prediction_cli/src/headless.rs                          |   6 
crates/edit_prediction_cli/src/load_project.rs                      | 320 
crates/edit_prediction_cli/src/main.rs                              | 594 
crates/edit_prediction_cli/src/metrics.rs                           |  51 
crates/edit_prediction_cli/src/paths.rs                             |  68 
crates/edit_prediction_cli/src/predict.rs                           | 569 
crates/edit_prediction_cli/src/retrieve_context.rs                  | 172 
crates/edit_prediction_cli/src/score.rs                             | 119 
crates/edit_prediction_cli/src/source_location.rs                   |  70 
crates/edit_prediction_cli/src/teacher.prompt.md                    |   4 
crates/edit_prediction_cli/src/training/context.rs                  |  89 
crates/edit_prediction_cli/src/training/distill.rs                  |  94 
crates/edit_prediction_cli/src/training/mod.rs                      |   4 
crates/edit_prediction_cli/src/training/teacher.rs                  | 266 
crates/edit_prediction_context/Cargo.toml                           |   1 
crates/edit_prediction_context/src/assemble_excerpts.rs             |  13 
crates/edit_prediction_context/src/edit_prediction_context.rs       | 161 
crates/edit_prediction_context/src/edit_prediction_context_tests.rs |  12 
crates/edit_prediction_ui/Cargo.toml                                |   2 
crates/edit_prediction_ui/src/edit_prediction_context_view.rs       |  47 
crates/edit_prediction_ui/src/rate_prediction_modal.rs              |  43 
crates/zeta_prompt/Cargo.toml                                       |  15 
crates/zeta_prompt/LICENSE-GPL                                      |   0 
crates/zeta_prompt/src/zeta_prompt.rs                               | 165 
44 files changed, 2,492 insertions(+), 4,791 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3111,16 +3111,6 @@ dependencies = [
  "uuid",
 ]
 
-[[package]]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "cloud_llm_client",
- "indoc",
- "serde",
-]
-
 [[package]]
 name = "cmake"
 version = "0.1.54"
@@ -5119,7 +5109,6 @@ dependencies = [
  "clock",
  "cloud_api_types",
  "cloud_llm_client",
- "cloud_zeta2_prompt",
  "collections",
  "copilot",
  "credentials_provider",
@@ -5150,8 +5139,6 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
- "smol",
- "strsim",
  "strum 0.27.2",
  "telemetry",
  "telemetry_events",
@@ -5162,6 +5149,7 @@ dependencies = [
  "workspace",
  "worktree",
  "zed_actions",
+ "zeta_prompt",
  "zlog",
 ]
 
@@ -5175,11 +5163,10 @@ dependencies = [
  "clap",
  "client",
  "cloud_llm_client",
- "cloud_zeta2_prompt",
  "collections",
  "debug_adapter_extension",
+ "dirs 4.0.0",
  "edit_prediction",
- "edit_prediction_context",
  "extension",
  "fs",
  "futures 0.3.31",
@@ -5209,9 +5196,10 @@ dependencies = [
  "sqlez",
  "sqlez_macros",
  "terminal_view",
- "toml 0.8.23",
  "util",
+ "wasmtime",
  "watch",
+ "zeta_prompt",
  "zlog",
 ]
 
@@ -5239,6 +5227,7 @@ dependencies = [
  "text",
  "tree-sitter",
  "util",
+ "zeta_prompt",
  "zlog",
 ]
 
@@ -5260,7 +5249,6 @@ dependencies = [
  "buffer_diff",
  "client",
  "cloud_llm_client",
- "cloud_zeta2_prompt",
  "codestral",
  "command_palette_hooks",
  "copilot",
@@ -5291,6 +5279,7 @@ dependencies = [
  "util",
  "workspace",
  "zed_actions",
+ "zeta_prompt",
 ]
 
 [[package]]
@@ -20933,6 +20922,13 @@ dependencies = [
  "syn 2.0.106",
 ]
 
+[[package]]
+name = "zeta_prompt"
+version = "0.1.0"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "zip"
 version = "0.6.6"

Cargo.toml 🔗

@@ -32,7 +32,6 @@ members = [
     "crates/cloud_api_client",
     "crates/cloud_api_types",
     "crates/cloud_llm_client",
-    "crates/cloud_zeta2_prompt",
     "crates/collab",
     "crates/collab_ui",
     "crates/collections",
@@ -202,6 +201,7 @@ members = [
     "crates/zed_actions",
     "crates/zed_env_vars",
     "crates/edit_prediction_cli",
+    "crates/zeta_prompt",
     "crates/zlog",
     "crates/zlog_settings",
     "crates/ztracing",
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
 cloud_api_client = { path = "crates/cloud_api_client" }
 cloud_api_types = { path = "crates/cloud_api_types" }
 cloud_llm_client = { path = "crates/cloud_llm_client" }
-cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
 collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections", version = "0.1.0" }
 command_palette = { path = "crates/command_palette" }
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
 zed_actions = { path = "crates/zed_actions" }
 zed_env_vars = { path = "crates/zed_env_vars" }
 edit_prediction = { path = "crates/edit_prediction" }
+zeta_prompt = { path = "crates/zeta_prompt" }
 zlog = { path = "crates/zlog" }
 zlog_settings = { path = "crates/zlog_settings" }
 ztracing = { path = "crates/ztracing" }
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
 tiny_http = "0.8"
 tokio = { version = "1" }
 tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
+tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
 toml = "0.8"
 toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
 tower-http = "0.4.4"

crates/client/Cargo.toml 🔗

@@ -53,7 +53,7 @@ text.workspace = true
 thiserror.workspace = true
 time.workspace = true
 tiny_http.workspace = true
-tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
+tokio-socks.workspace = true
 tokio.workspace = true
 url.workspace = true
 util.workspace = true

crates/cloud_zeta2_prompt/Cargo.toml 🔗

@@ -1,18 +0,0 @@
-[package]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-publish.workspace = true
-edition.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/cloud_zeta2_prompt.rs"
-
-[dependencies]
-anyhow.workspace = true
-cloud_llm_client.workspace = true
-indoc.workspace = true
-serde.workspace = true

crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs 🔗

@@ -1,485 +0,0 @@
-use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::{
-    self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
-};
-use indoc::indoc;
-use std::cmp;
-use std::fmt::Write;
-use std::path::Path;
-use std::sync::Arc;
-
-pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
-
-pub const CURSOR_MARKER: &str = "<|user_cursor|>";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
-
-const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
-    You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
-
-    ## Edit History
-
-    "#};
-
-const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
-    ---
-
-    Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
-    Do not include the cursor marker in your output.
-    If you're editing multiple files, be sure to reflect filename in the hunk's header.
-    "};
-
-const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
-    # Instructions
-
-    You are an edit prediction agent in a code editor.
-
-    Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
-    Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
-    Always continue along the user's current trajectory, rather than changing course.
-
-    ## Output Format
-
-    You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
-    along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
-
-    <edits path="my-project/src/myapp/cli.py">
-    <old_text>
-    OLD TEXT 1 HERE
-    </old_text>
-    <new_text>
-    NEW TEXT 1 HERE
-    </new_text>
-
-    <old_text>
-    OLD TEXT 1 HERE
-    </old_text>
-    <new_text>
-    NEW TEXT 1 HERE
-    </new_text>
-    </edits>
-
-    - Specify the file to edit using the `path` attribute.
-    - Use `<old_text>` and `<new_text>` tags to replace content
-    - `<old_text>` must exactly match existing file content, including indentation
-    - `<old_text>` cannot be empty
-    - Do not escape quotes, newlines, or other characters within tags
-    - Always close all tags properly
-    - Don't include the <|user_cursor|> marker in your output.
-
-    ## Edit History
-
-"#};
-
-const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
-    ---
-
-    Remember that the edits in the edit history have already been applied.
-"#};
-
-pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
-    let prompt_data = PromptData {
-        events: request.events.clone(),
-        cursor_point: request.cursor_point,
-        cursor_path: request.excerpt_path.clone(),
-        included_files: request.related_files.clone(),
-    };
-    match request.prompt_format {
-        PromptFormat::MinimalQwen => {
-            return Ok(MinimalQwenPrompt.render(&prompt_data));
-        }
-        PromptFormat::SeedCoder1120 => {
-            return Ok(SeedCoder1120Prompt.render(&prompt_data));
-        }
-        _ => (),
-    };
-
-    let insertions = match request.prompt_format {
-        PromptFormat::Minimal | PromptFormat::OldTextNewText => {
-            vec![(request.cursor_point, CURSOR_MARKER)]
-        }
-        PromptFormat::OnlySnippets => vec![],
-        PromptFormat::MinimalQwen => unreachable!(),
-        PromptFormat::SeedCoder1120 => unreachable!(),
-    };
-
-    let mut prompt = match request.prompt_format {
-        PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
-        PromptFormat::OnlySnippets => String::new(),
-        PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
-        PromptFormat::MinimalQwen => unreachable!(),
-        PromptFormat::SeedCoder1120 => unreachable!(),
-    };
-
-    if request.events.is_empty() {
-        prompt.push_str("(No edit history)\n\n");
-    } else {
-        let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
-            "The following are the latest edits made by the user, from earlier to later.\n\n"
-        } else {
-            "Here are the latest edits made by the user, from earlier to later.\n\n"
-        };
-        prompt.push_str(edit_preamble);
-        push_events(&mut prompt, &request.events);
-    }
-
-    let excerpts_preamble = match request.prompt_format {
-        PromptFormat::Minimal => indoc! {"
-             ## Part of the file under the cursor
-
-             (The cursor marker <|user_cursor|> indicates the current user cursor position.
-             The file is in current state, edits from edit history has been applied.
-             We only show part of the file around the cursor.
-             You can only edit exactly this part of the file.
-             We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
-             "},
-        PromptFormat::OldTextNewText => indoc! {"
-            ## Code Excerpts
-
-            Here is some excerpts of code that you should take into account to predict the next edit.
-
-            The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
-
-            In addition other excerpts are included to better understand what the edit will be, including the declaration
-            or references of symbols around the cursor, or other similar code snippets that may need to be updated
-            following patterns that appear in the edit history.
-
-            Consider each of them carefully in relation to the edit history, and that the user may not have navigated
-            to the next place they want to edit yet.
-
-            Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
-        "},
-        PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
-            indoc! {"
-            ## Code Excerpts
-
-            The cursor marker <|user_cursor|> indicates the current user cursor position.
-            The file is in current state, edits from edit history have been applied.
-        "}
-        }
-    };
-
-    prompt.push_str(excerpts_preamble);
-    prompt.push('\n');
-
-    let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
-    for related_file in &request.related_files {
-        if request.prompt_format == PromptFormat::Minimal {
-            write_codeblock_with_filename(
-                &related_file.path,
-                &related_file.excerpts,
-                if related_file.path == request.excerpt_path {
-                    &insertions
-                } else {
-                    &[]
-                },
-                related_file.max_row,
-                include_line_numbers,
-                &mut prompt,
-            );
-        } else {
-            write_codeblock(
-                &related_file.path,
-                &related_file.excerpts,
-                if related_file.path == request.excerpt_path {
-                    &insertions
-                } else {
-                    &[]
-                },
-                related_file.max_row,
-                include_line_numbers,
-                &mut prompt,
-            );
-        }
-    }
-
-    match request.prompt_format {
-        PromptFormat::OldTextNewText => {
-            prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
-        }
-        PromptFormat::Minimal => {
-            prompt.push_str(MINIMAL_PROMPT_REMINDER);
-        }
-        _ => {}
-    }
-
-    Ok(prompt)
-}
-
-pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
-    match prompt_format {
-        PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
-        _ => GenerationParams::default(),
-    }
-}
-
-pub fn write_codeblock<'a>(
-    path: &Path,
-    excerpts: impl IntoIterator<Item = &'a Excerpt>,
-    sorted_insertions: &[(Point, &str)],
-    file_line_count: Line,
-    include_line_numbers: bool,
-    output: &'a mut String,
-) {
-    writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
-
-    write_excerpts(
-        excerpts,
-        sorted_insertions,
-        file_line_count,
-        include_line_numbers,
-        output,
-    );
-    write!(output, "`````\n\n").unwrap();
-}
-
-fn write_codeblock_with_filename<'a>(
-    path: &Path,
-    excerpts: impl IntoIterator<Item = &'a Excerpt>,
-    sorted_insertions: &[(Point, &str)],
-    file_line_count: Line,
-    include_line_numbers: bool,
-    output: &'a mut String,
-) {
-    writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
-
-    write_excerpts(
-        excerpts,
-        sorted_insertions,
-        file_line_count,
-        include_line_numbers,
-        output,
-    );
-    write!(output, "`````\n\n").unwrap();
-}
-
-pub fn write_excerpts<'a>(
-    excerpts: impl IntoIterator<Item = &'a Excerpt>,
-    sorted_insertions: &[(Point, &str)],
-    file_line_count: Line,
-    include_line_numbers: bool,
-    output: &mut String,
-) {
-    let mut current_row = Line(0);
-    let mut sorted_insertions = sorted_insertions.iter().peekable();
-
-    for excerpt in excerpts {
-        if excerpt.start_line > current_row {
-            writeln!(output, "…").unwrap();
-        }
-        if excerpt.text.is_empty() {
-            return;
-        }
-
-        current_row = excerpt.start_line;
-
-        for mut line in excerpt.text.lines() {
-            if include_line_numbers {
-                write!(output, "{}|", current_row.0 + 1).unwrap();
-            }
-
-            while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
-                match current_row.cmp(&insertion_location.line) {
-                    cmp::Ordering::Equal => {
-                        let (prefix, suffix) = line.split_at(insertion_location.column as usize);
-                        output.push_str(prefix);
-                        output.push_str(insertion_marker);
-                        line = suffix;
-                        sorted_insertions.next();
-                    }
-                    cmp::Ordering::Less => break,
-                    cmp::Ordering::Greater => {
-                        sorted_insertions.next();
-                        break;
-                    }
-                }
-            }
-            output.push_str(line);
-            output.push('\n');
-            current_row.0 += 1;
-        }
-    }
-
-    if current_row < file_line_count {
-        writeln!(output, "…").unwrap();
-    }
-}
-
-pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
-    if events.is_empty() {
-        return;
-    };
-
-    writeln!(output, "`````diff").unwrap();
-    for event in events {
-        writeln!(output, "{}", event).unwrap();
-    }
-    writeln!(output, "`````\n").unwrap();
-}
-
-struct PromptData {
-    events: Vec<Arc<Event>>,
-    cursor_point: Point,
-    cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
-    included_files: Vec<RelatedFile>,
-}
-
-#[derive(Default)]
-pub struct GenerationParams {
-    pub temperature: Option<f32>,
-    pub top_p: Option<f32>,
-    pub stop: Option<Vec<String>>,
-}
-
-trait PromptFormatter {
-    fn render(&self, data: &PromptData) -> String;
-
-    fn generation_params() -> GenerationParams {
-        return GenerationParams::default();
-    }
-}
-
-struct MinimalQwenPrompt;
-
-impl PromptFormatter for MinimalQwenPrompt {
-    fn render(&self, data: &PromptData) -> String {
-        let edit_history = self.fmt_edit_history(data);
-        let context = self.fmt_context(data);
-
-        format!(
-            "{instructions}\n\n{edit_history}\n\n{context}",
-            instructions = MinimalQwenPrompt::INSTRUCTIONS,
-            edit_history = edit_history,
-            context = context
-        )
-    }
-}
-
-impl MinimalQwenPrompt {
-    const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
-
-    fn fmt_edit_history(&self, data: &PromptData) -> String {
-        if data.events.is_empty() {
-            "(No edit history)\n\n".to_string()
-        } else {
-            let mut events_str = String::new();
-            push_events(&mut events_str, &data.events);
-            format!(
-                "The following are the latest edits made by the user, from earlier to later.\n\n{}",
-                events_str
-            )
-        }
-    }
-
-    fn fmt_context(&self, data: &PromptData) -> String {
-        let mut context = String::new();
-        let include_line_numbers = true;
-
-        for related_file in &data.included_files {
-            writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
-
-            if related_file.path == data.cursor_path {
-                write!(context, "<|fim_prefix|>").unwrap();
-                write_excerpts(
-                    &related_file.excerpts,
-                    &[(data.cursor_point, "<|fim_suffix|>")],
-                    related_file.max_row,
-                    include_line_numbers,
-                    &mut context,
-                );
-                writeln!(context, "<|fim_middle|>").unwrap();
-            } else {
-                write_excerpts(
-                    &related_file.excerpts,
-                    &[],
-                    related_file.max_row,
-                    include_line_numbers,
-                    &mut context,
-                );
-            }
-        }
-        context
-    }
-}
-
-struct SeedCoder1120Prompt;
-
-impl PromptFormatter for SeedCoder1120Prompt {
-    fn render(&self, data: &PromptData) -> String {
-        let edit_history = self.fmt_edit_history(data);
-        let context = self.fmt_context(data);
-
-        format!(
-            "# Edit History:\n{edit_history}\n\n{context}",
-            edit_history = edit_history,
-            context = context
-        )
-    }
-
-    fn generation_params() -> GenerationParams {
-        GenerationParams {
-            temperature: Some(0.2),
-            top_p: Some(0.9),
-            stop: Some(vec!["<[end_of_sentence]>".into()]),
-        }
-    }
-}
-
-impl SeedCoder1120Prompt {
-    fn fmt_edit_history(&self, data: &PromptData) -> String {
-        if data.events.is_empty() {
-            "(No edit history)\n\n".to_string()
-        } else {
-            let mut events_str = String::new();
-            push_events(&mut events_str, &data.events);
-            events_str
-        }
-    }
-
-    fn fmt_context(&self, data: &PromptData) -> String {
-        let mut context = String::new();
-        let include_line_numbers = true;
-
-        for related_file in &data.included_files {
-            writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
-
-            if related_file.path == data.cursor_path {
-                let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
-                context.push_str(&fim_prompt);
-            } else {
-                write_excerpts(
-                    &related_file.excerpts,
-                    &[],
-                    related_file.max_row,
-                    include_line_numbers,
-                    &mut context,
-                );
-            }
-        }
-        context
-    }
-
-    fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
-        let mut buf = String::new();
-        const FIM_SUFFIX: &str = "<[fim-suffix]>";
-        const FIM_PREFIX: &str = "<[fim-prefix]>";
-        const FIM_MIDDLE: &str = "<[fim-middle]>";
-        write!(buf, "{}", FIM_PREFIX).unwrap();
-        write_excerpts(
-            &file.excerpts,
-            &[(cursor_point, FIM_SUFFIX)],
-            file.max_row,
-            true,
-            &mut buf,
-        );
-
-        // Swap prefix and suffix parts
-        let index = buf.find(FIM_SUFFIX).unwrap();
-        let prefix = &buf[..index];
-        let suffix = &buf[index..];
-
-        format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
-    }
-}

crates/edit_prediction/Cargo.toml 🔗

@@ -21,7 +21,6 @@ arrayvec.workspace = true
 brotli.workspace = true
 client.workspace = true
 cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
 collections.workspace = true
 copilot.workspace = true
 credentials_provider.workspace = true
@@ -50,8 +49,6 @@ semver.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
-smol.workspace = true
-strsim.workspace = true
 strum.workspace = true
 telemetry.workspace = true
 telemetry_events.workspace = true
@@ -62,6 +59,7 @@ uuid.workspace = true
 workspace.workspace = true
 worktree.workspace = true
 zed_actions.workspace = true
+zeta_prompt.workspace = true
 
 [dev-dependencies]
 clock = { workspace = true, features = ["test-support"] }

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -1,14 +1,13 @@
 use anyhow::Result;
 use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
 use cloud_llm_client::{
     AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
     EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
     MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
     ZED_VERSION_HEADER_NAME,
 };
-use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
 use collections::{HashMap, HashSet};
 use db::kvp::{Dismissable, KEY_VALUE_STORE};
 use edit_prediction_context::EditPredictionExcerptOptions;
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
 use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
 use futures::{
     AsyncReadExt as _, FutureExt as _, StreamExt as _,
-    channel::{
-        mpsc::{self, UnboundedReceiver},
-        oneshot,
-    },
+    channel::mpsc::{self, UnboundedReceiver},
     select_biased,
 };
 use gpui::BackgroundExecutor;
@@ -58,8 +54,10 @@ mod onboarding_modal;
 pub mod open_ai_response;
 mod prediction;
 pub mod sweep_ai;
+
+#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
 pub mod udiff;
-mod xml_edits;
+
 mod zed_edit_prediction_delegate;
 pub mod zeta1;
 pub mod zeta2;
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
 use crate::onboarding_modal::ZedPredictModal;
 pub use crate::prediction::EditPrediction;
 pub use crate::prediction::EditPredictionId;
-pub use crate::prediction::EditPredictionInputs;
 use crate::prediction::EditPredictionResult;
 pub use crate::sweep_ai::SweepAi;
 pub use telemetry_events::EditPredictionRating;
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
         min_bytes: 128,
         target_before_cursor_over_total_bytes: 0.5,
     },
-    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
     prompt_format: PromptFormat::DEFAULT,
 };
 
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
     use_context: bool,
     options: ZetaOptions,
     update_required: bool,
-    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
     #[cfg(feature = "eval-support")]
     eval_cache: Option<Arc<dyn EvalCache>>,
     edit_prediction_model: EditPredictionModel,
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
     Mercury,
 }
 
+pub struct EditPredictionModelInput {
+    project: Entity<Project>,
+    buffer: Entity<Buffer>,
+    snapshot: BufferSnapshot,
+    position: Anchor,
+    events: Vec<Arc<zeta_prompt::Event>>,
+    related_files: Arc<[RelatedFile]>,
+    recent_paths: VecDeque<ProjectPath>,
+    trigger: PredictEditsRequestTrigger,
+    diagnostic_search_range: Range<Point>,
+    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
+}
+
 #[derive(Debug, Clone, PartialEq)]
 pub struct ZetaOptions {
     pub context: EditPredictionExcerptOptions,
-    pub max_prompt_bytes: usize,
     pub prompt_format: predict_edits_v3::PromptFormat,
 }
 
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
 pub enum DebugEvent {
     ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
     ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
-    EditPredictionRequested(EditPredictionRequestedDebugEvent),
+    EditPredictionStarted(EditPredictionStartedDebugEvent),
+    EditPredictionFinished(EditPredictionFinishedDebugEvent),
 }
 
 #[derive(Debug)]
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
 }
 
 #[derive(Debug)]
-pub struct EditPredictionRequestedDebugEvent {
-    pub inputs: EditPredictionInputs,
-    pub retrieval_time: Duration,
+pub struct EditPredictionStartedDebugEvent {
     pub buffer: WeakEntity<Buffer>,
     pub position: Anchor,
-    pub local_prompt: Result<String, String>,
-    pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
+    pub prompt: Option<String>,
+}
+
+#[derive(Debug)]
+pub struct EditPredictionFinishedDebugEvent {
+    pub buffer: WeakEntity<Buffer>,
+    pub position: Anchor,
+    pub model_output: Option<String>,
 }
 
 pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 
 struct ProjectState {
-    events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+    events: VecDeque<Arc<zeta_prompt::Event>>,
     last_event: Option<LastEvent>,
     recent_paths: VecDeque<ProjectPath>,
     registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
     current_prediction: Option<CurrentEditPrediction>,
     next_pending_prediction_id: usize,
     pending_predictions: ArrayVec<PendingPrediction, 2>,
-    context_updates_tx: smol::channel::Sender<()>,
-    context_updates_rx: smol::channel::Receiver<()>,
+    debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
     last_prediction_refresh: Option<(EntityId, Instant)>,
     cancelled_predictions: HashSet<usize>,
     context: Entity<RelatedExcerptStore>,
@@ -241,7 +252,7 @@ struct ProjectState {
 }
 
 impl ProjectState {
-    pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
+    pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
         self.events
             .iter()
             .cloned()
@@ -376,7 +387,7 @@ impl LastEvent {
         &self,
         license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
         cx: &App,
-    ) -> Option<Arc<predict_edits_v3::Event>> {
+    ) -> Option<Arc<zeta_prompt::Event>> {
         let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
         let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
 
@@ -396,7 +407,7 @@ impl LastEvent {
         if path == old_path && diff.is_empty() {
             None
         } else {
-            Some(Arc::new(predict_edits_v3::Event::BufferChange {
+            Some(Arc::new(zeta_prompt::Event::BufferChange {
                 old_path,
                 path,
                 diff,
@@ -481,7 +492,6 @@ impl EditPredictionStore {
                 },
             ),
             update_required: false,
-            debug_tx: None,
             #[cfg(feature = "eval-support")]
             eval_cache: None,
             edit_prediction_model: EditPredictionModel::Zeta2,
@@ -536,12 +546,6 @@ impl EditPredictionStore {
         self.eval_cache = Some(cache);
     }
 
-    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
-        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
-        self.debug_tx = Some(debug_watch_tx);
-        debug_watch_rx
-    }
-
     pub fn options(&self) -> &ZetaOptions {
         &self.options
     }
@@ -560,15 +564,35 @@ impl EditPredictionStore {
         }
     }
 
+    pub fn edit_history_for_project(
+        &self,
+        project: &Entity<Project>,
+    ) -> Vec<Arc<zeta_prompt::Event>> {
+        self.projects
+            .get(&project.entity_id())
+            .map(|project_state| project_state.events.iter().cloned().collect())
+            .unwrap_or_default()
+    }
+
     pub fn context_for_project<'a>(
         &'a self,
         project: &Entity<Project>,
         cx: &'a App,
-    ) -> &'a [RelatedFile] {
+    ) -> Arc<[RelatedFile]> {
         self.projects
             .get(&project.entity_id())
             .map(|project| project.context.read(cx).related_files())
-            .unwrap_or(&[])
+            .unwrap_or_else(|| vec![].into())
+    }
+
+    pub fn context_for_project_with_buffers<'a>(
+        &'a self,
+        project: &Entity<Project>,
+        cx: &'a App,
+    ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
+        self.projects
+            .get(&project.entity_id())
+            .map(|project| project.context.read(cx).related_files_with_buffers())
     }
 
     pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -599,85 +623,21 @@ impl EditPredictionStore {
         cx: &mut Context<Self>,
     ) -> &mut ProjectState {
         let entity_id = project.entity_id();
-        let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
         self.projects
             .entry(entity_id)
             .or_insert_with(|| ProjectState {
                 context: {
                     let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
-                    cx.subscribe(
-                        &related_excerpt_store,
-                        move |this, _, event, _| match event {
-                            RelatedExcerptStoreEvent::StartedRefresh => {
-                                if let Some(debug_tx) = this.debug_tx.clone() {
-                                    debug_tx
-                                        .unbounded_send(DebugEvent::ContextRetrievalStarted(
-                                            ContextRetrievalStartedDebugEvent {
-                                                project_entity_id: entity_id,
-                                                timestamp: Instant::now(),
-                                                search_prompt: String::new(),
-                                            },
-                                        ))
-                                        .ok();
-                                }
-                            }
-                            RelatedExcerptStoreEvent::FinishedRefresh {
-                                cache_hit_count,
-                                cache_miss_count,
-                                mean_definition_latency,
-                                max_definition_latency,
-                            } => {
-                                if let Some(debug_tx) = this.debug_tx.clone() {
-                                    debug_tx
-                                        .unbounded_send(DebugEvent::ContextRetrievalFinished(
-                                            ContextRetrievalFinishedDebugEvent {
-                                                project_entity_id: entity_id,
-                                                timestamp: Instant::now(),
-                                                metadata: vec![
-                                                    (
-                                                        "Cache Hits",
-                                                        format!(
-                                                            "{}/{}",
-                                                            cache_hit_count,
-                                                            cache_hit_count + cache_miss_count
-                                                        )
-                                                        .into(),
-                                                    ),
-                                                    (
-                                                        "Max LSP Time",
-                                                        format!(
-                                                            "{} ms",
-                                                            max_definition_latency.as_millis()
-                                                        )
-                                                        .into(),
-                                                    ),
-                                                    (
-                                                        "Mean LSP Time",
-                                                        format!(
-                                                            "{} ms",
-                                                            mean_definition_latency.as_millis()
-                                                        )
-                                                        .into(),
-                                                    ),
-                                                ],
-                                            },
-                                        ))
-                                        .ok();
-                                }
-                                if let Some(project_state) = this.projects.get(&entity_id) {
-                                    project_state.context_updates_tx.send_blocking(()).ok();
-                                }
-                            }
-                        },
-                    )
+                    cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
+                        this.handle_excerpt_store_event(entity_id, event);
+                    })
                     .detach();
                     related_excerpt_store
                 },
                 events: VecDeque::new(),
                 last_event: None,
                 recent_paths: VecDeque::new(),
-                context_updates_rx,
-                context_updates_tx,
+                debug_tx: None,
                 registered_buffers: HashMap::default(),
                 current_prediction: None,
                 cancelled_predictions: HashSet::default(),
@@ -689,12 +649,79 @@ impl EditPredictionStore {
             })
     }
 
-    pub fn project_context_updates(
-        &self,
+    pub fn remove_project(&mut self, project: &Entity<Project>) {
+        self.projects.remove(&project.entity_id());
+    }
+
+    fn handle_excerpt_store_event(
+        &mut self,
+        project_entity_id: EntityId,
+        event: &RelatedExcerptStoreEvent,
+    ) {
+        if let Some(project_state) = self.projects.get(&project_entity_id) {
+            if let Some(debug_tx) = project_state.debug_tx.clone() {
+                match event {
+                    RelatedExcerptStoreEvent::StartedRefresh => {
+                        debug_tx
+                            .unbounded_send(DebugEvent::ContextRetrievalStarted(
+                                ContextRetrievalStartedDebugEvent {
+                                    project_entity_id: project_entity_id,
+                                    timestamp: Instant::now(),
+                                    search_prompt: String::new(),
+                                },
+                            ))
+                            .ok();
+                    }
+                    RelatedExcerptStoreEvent::FinishedRefresh {
+                        cache_hit_count,
+                        cache_miss_count,
+                        mean_definition_latency,
+                        max_definition_latency,
+                    } => {
+                        debug_tx
+                            .unbounded_send(DebugEvent::ContextRetrievalFinished(
+                                ContextRetrievalFinishedDebugEvent {
+                                    project_entity_id: project_entity_id,
+                                    timestamp: Instant::now(),
+                                    metadata: vec![
+                                        (
+                                            "Cache Hits",
+                                            format!(
+                                                "{}/{}",
+                                                cache_hit_count,
+                                                cache_hit_count + cache_miss_count
+                                            )
+                                            .into(),
+                                        ),
+                                        (
+                                            "Max LSP Time",
+                                            format!("{} ms", max_definition_latency.as_millis())
+                                                .into(),
+                                        ),
+                                        (
+                                            "Mean LSP Time",
+                                            format!("{} ms", mean_definition_latency.as_millis())
+                                                .into(),
+                                        ),
+                                    ],
+                                },
+                            ))
+                            .ok();
+                    }
+                }
+            }
+        }
+    }
+
+    pub fn debug_info(
+        &mut self,
         project: &Entity<Project>,
-    ) -> Option<smol::channel::Receiver<()>> {
-        let project_state = self.projects.get(&project.entity_id())?;
-        Some(project_state.context_updates_rx.clone())
+        cx: &mut Context<Self>,
+    ) -> mpsc::UnboundedReceiver<DebugEvent> {
+        let project_state = self.get_or_init_project(project, cx);
+        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
+        project_state.debug_tx = Some(debug_watch_tx);
+        debug_watch_rx
     }
 
     fn handle_project_event(
@@ -1348,6 +1375,7 @@ impl EditPredictionStore {
         let project_state = self.projects.get(&project.entity_id()).unwrap();
         let events = project_state.events(cx);
         let has_events = !events.is_empty();
+        let debug_tx = project_state.debug_tx.clone();
 
         let snapshot = active_buffer.read(cx).snapshot();
         let cursor_point = position.to_point(&snapshot);
@@ -1357,55 +1385,29 @@ impl EditPredictionStore {
             Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
 
         let related_files = if self.use_context {
-            self.context_for_project(&project, cx).to_vec()
+            self.context_for_project(&project, cx)
         } else {
-            Vec::new()
+            Vec::new().into()
+        };
+
+        let inputs = EditPredictionModelInput {
+            project: project.clone(),
+            buffer: active_buffer.clone(),
+            snapshot: snapshot.clone(),
+            position,
+            events,
+            related_files,
+            recent_paths: project_state.recent_paths.clone(),
+            trigger,
+            diagnostic_search_range: diagnostic_search_range.clone(),
+            debug_tx,
         };
 
         let task = match self.edit_prediction_model {
-            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
-                self,
-                &project,
-                &active_buffer,
-                snapshot.clone(),
-                position,
-                events,
-                trigger,
-                cx,
-            ),
-            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
-                self,
-                &project,
-                &active_buffer,
-                snapshot.clone(),
-                position,
-                events,
-                related_files,
-                trigger,
-                cx,
-            ),
-            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
-                &project,
-                &active_buffer,
-                snapshot.clone(),
-                position,
-                events,
-                &project_state.recent_paths,
-                related_files,
-                diagnostic_search_range.clone(),
-                cx,
-            ),
-            EditPredictionModel::Mercury => self.mercury.request_prediction(
-                &project,
-                &active_buffer,
-                snapshot.clone(),
-                position,
-                events,
-                &project_state.recent_paths,
-                related_files,
-                diagnostic_search_range.clone(),
-                cx,
-            ),
+            EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
+            EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
+            EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
+            EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
         };
 
         cx.spawn(async move |this, cx| {
@@ -1706,6 +1708,20 @@ impl EditPredictionStore {
         }
     }
 
+    #[cfg(feature = "eval-support")]
+    pub fn set_context_for_buffer(
+        &mut self,
+        project: &Entity<Project>,
+        related_files: Vec<RelatedFile>,
+        cx: &mut Context<Self>,
+    ) {
+        self.get_or_init_project(project, cx)
+            .context
+            .update(cx, |store, _| {
+                store.set_related_files(related_files);
+            });
+    }
+
     fn is_file_open_source(
         &self,
         project: &Entity<Project>,
@@ -1729,14 +1745,14 @@ impl EditPredictionStore {
         self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
     }
 
-    fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
+    fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
         if !self.data_collection_choice.is_enabled() {
             return false;
         }
         events.iter().all(|event| {
             matches!(
                 event.as_ref(),
-                Event::BufferChange {
+                zeta_prompt::Event::BufferChange {
                     in_open_source_repo: true,
                     ..
                 }

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1,5 +1,5 @@
 use super::*;
-use crate::zeta1::MAX_EVENT_TOKENS;
+use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
 use client::{UserStore, test::FakeServer};
 use clock::{FakeSystemClock, ReplicaId};
 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -7,7 +7,6 @@ use cloud_llm_client::{
     EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
     RejectEditPredictionsBody,
 };
-use edit_prediction_context::Line;
 use futures::{
     AsyncReadExt, StreamExt,
     channel::{mpsc, oneshot},
@@ -28,6 +27,7 @@ use settings::SettingsStore;
 use std::{path::Path, sync::Arc, time::Duration};
 use util::{path, rel_path::rel_path};
 use uuid::Uuid;
+use zeta_prompt::ZetaPromptInput;
 
 use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
 
@@ -65,18 +65,21 @@ async fn test_current_state(cx: &mut TestAppContext) {
     ep_store.update(cx, |ep_store, cx| {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
     });
-    let (_request, respond_tx) = requests.predict.next().await.unwrap();
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
 
     respond_tx
-        .send(model_response(indoc! {r"
-            --- a/root/1.txt
-            +++ b/root/1.txt
-            @@ ... @@
-             Hello!
-            -How
-            +How are you?
-             Bye
-        "}))
+        .send(model_response(
+            request,
+            indoc! {r"
+                --- a/root/1.txt
+                +++ b/root/1.txt
+                @@ ... @@
+                 Hello!
+                -How
+                +How are you?
+                 Bye
+            "},
+        ))
         .unwrap();
 
     cx.run_until_parked();
@@ -120,16 +123,20 @@ async fn test_current_state(cx: &mut TestAppContext) {
         });
     });
 
-    let (_request, respond_tx) = requests.predict.next().await.unwrap();
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
     respond_tx
-        .send(model_response(indoc! {r#"
-            --- a/root/2.txt
-            +++ b/root/2.txt
-             Hola!
-            -Como
-            +Como estas?
-             Adios
-        "#}))
+        .send(model_response(
+            request,
+            indoc! {r#"
+                --- a/root/2.txt
+                +++ b/root/2.txt
+                @@ ... @@
+                 Hola!
+                -Como
+                +Como estas?
+                 Adios
+            "#},
+        ))
         .unwrap();
     cx.run_until_parked();
 
@@ -186,7 +193,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
         ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
     });
 
-    let (_, respond_tx) = requests.predict.next().await.unwrap();
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
 
     // TODO Put back when we have a structured request again
     // assert_eq!(
@@ -202,15 +209,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
     // );
 
     respond_tx
-        .send(model_response(indoc! { r"
-            --- a/root/foo.md
-            +++ b/root/foo.md
-            @@ ... @@
-             Hello!
-            -How
-            +How are you?
-             Bye
-        "}))
+        .send(model_response(
+            request,
+            indoc! { r"
+                --- a/root/foo.md
+                +++ b/root/foo.md
+                @@ ... @@
+                 Hello!
+                -How
+                +How are you?
+                 Bye
+            "},
+        ))
         .unwrap();
 
     let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -276,15 +286,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
     );
 
     respond_tx
-        .send(model_response(indoc! {r#"
-            --- a/root/foo.md
-            +++ b/root/foo.md
-            @@ ... @@
-             Hello!
-            -How
-            +How are you?
-             Bye
-        "#}))
+        .send(model_response(
+            request,
+            indoc! {r#"
+                --- a/root/foo.md
+                +++ b/root/foo.md
+                @@ ... @@
+                 Hello!
+                -How
+                +How are you?
+                 Bye
+        "#},
+        ))
         .unwrap();
 
     let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -324,18 +337,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    const NO_OP_DIFF: &str = indoc! { r"
-        --- a/root/foo.md
-        +++ b/root/foo.md
-        @@ ... @@
-         Hello!
-        -How
-        +How
-         Bye
-    "};
-
-    let (_, respond_tx) = requests.predict.next().await.unwrap();
-    let response = model_response(NO_OP_DIFF);
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
+    let response = model_response(request, "");
     let id = response.id.clone();
     respond_tx.send(response).unwrap();
 
@@ -389,13 +392,13 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_tx) = requests.predict.next().await.unwrap();
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
 
     buffer.update(cx, |buffer, cx| {
         buffer.set_text("Hello!\nHow are you?\nBye", cx);
     });
 
-    let response = model_response(SIMPLE_DIFF);
+    let response = model_response(request, SIMPLE_DIFF);
     let id = response.id.clone();
     respond_tx.send(response).unwrap();
 
@@ -459,8 +462,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_tx) = requests.predict.next().await.unwrap();
-    let first_response = model_response(SIMPLE_DIFF);
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
+    let first_response = model_response(request, SIMPLE_DIFF);
     let first_id = first_response.id.clone();
     respond_tx.send(first_response).unwrap();
 
@@ -482,8 +485,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_tx) = requests.predict.next().await.unwrap();
-    let second_response = model_response(SIMPLE_DIFF);
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
+    let second_response = model_response(request, SIMPLE_DIFF);
     let second_id = second_response.id.clone();
     respond_tx.send(second_response).unwrap();
 
@@ -541,8 +544,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_tx) = requests.predict.next().await.unwrap();
-    let first_response = model_response(SIMPLE_DIFF);
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
+    let first_response = model_response(request, SIMPLE_DIFF);
     let first_id = first_response.id.clone();
     respond_tx.send(first_response).unwrap();
 
@@ -564,17 +567,20 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_tx) = requests.predict.next().await.unwrap();
+    let (request, respond_tx) = requests.predict.next().await.unwrap();
     // worse than current prediction
-    let second_response = model_response(indoc! { r"
-        --- a/root/foo.md
-        +++ b/root/foo.md
-        @@ ... @@
-         Hello!
-        -How
-        +How are
-         Bye
-    "});
+    let second_response = model_response(
+        request,
+        indoc! { r"
+            --- a/root/foo.md
+            +++ b/root/foo.md
+            @@ ... @@
+             Hello!
+            -How
+            +How are
+             Bye
+        "},
+    );
     let second_id = second_response.id.clone();
     respond_tx.send(second_response).unwrap();
 
@@ -633,19 +639,19 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_first) = requests.predict.next().await.unwrap();
+    let (request1, respond_first) = requests.predict.next().await.unwrap();
 
     ep_store.update(cx, |ep_store, cx| {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_second) = requests.predict.next().await.unwrap();
+    let (request, respond_second) = requests.predict.next().await.unwrap();
 
     // wait for throttle
     cx.run_until_parked();
 
     // second responds first
-    let second_response = model_response(SIMPLE_DIFF);
+    let second_response = model_response(request, SIMPLE_DIFF);
     let second_id = second_response.id.clone();
     respond_second.send(second_response).unwrap();
 
@@ -663,7 +669,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
         );
     });
 
-    let first_response = model_response(SIMPLE_DIFF);
+    let first_response = model_response(request1, SIMPLE_DIFF);
     let first_id = first_response.id.clone();
     respond_first.send(first_response).unwrap();
 
@@ -724,13 +730,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_first) = requests.predict.next().await.unwrap();
+    let (request1, respond_first) = requests.predict.next().await.unwrap();
 
     ep_store.update(cx, |ep_store, cx| {
         ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
     });
 
-    let (_, respond_second) = requests.predict.next().await.unwrap();
+    let (request2, respond_second) = requests.predict.next().await.unwrap();
 
     // wait for throttle, so requests are sent
     cx.run_until_parked();
@@ -754,9 +760,9 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
     // wait for throttle
     cx.run_until_parked();
 
-    let (_, respond_third) = requests.predict.next().await.unwrap();
+    let (request3, respond_third) = requests.predict.next().await.unwrap();
 
-    let first_response = model_response(SIMPLE_DIFF);
+    let first_response = model_response(request1, SIMPLE_DIFF);
     let first_id = first_response.id.clone();
     respond_first.send(first_response).unwrap();
 
@@ -774,7 +780,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
         );
     });
 
-    let cancelled_response = model_response(SIMPLE_DIFF);
+    let cancelled_response = model_response(request2, SIMPLE_DIFF);
     let cancelled_id = cancelled_response.id.clone();
     respond_second.send(cancelled_response).unwrap();
 
@@ -792,7 +798,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
         );
     });
 
-    let third_response = model_response(SIMPLE_DIFF);
+    let third_response = model_response(request3, SIMPLE_DIFF);
     let third_response_id = third_response.id.clone();
     respond_third.send(third_response).unwrap();
 
@@ -1036,7 +1042,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
 //     );
 // }
 
-fn model_response(text: &str) -> open_ai::Response {
+// Generate a model response that would apply the given diff to the active file.
+fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
+    let prompt = match &request.messages[0] {
+        open_ai::RequestMessage::User {
+            content: open_ai::MessageContent::Plain(content),
+        } => content,
+        _ => panic!("unexpected request {request:?}"),
+    };
+
+    let open = "<editable_region>\n";
+    let close = "</editable_region>";
+    let cursor = "<|user_cursor|>";
+
+    let start_ix = open.len() + prompt.find(open).unwrap();
+    let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
+    let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
+    let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
+
     open_ai::Response {
         id: Uuid::new_v4().to_string(),
         object: "response".into(),
@@ -1045,7 +1068,7 @@ fn model_response(text: &str) -> open_ai::Response {
         choices: vec![open_ai::Choice {
             index: 0,
             message: open_ai::RequestMessage::Assistant {
-                content: Some(open_ai::MessageContent::Plain(text.to_string())),
+                content: Some(open_ai::MessageContent::Plain(new_excerpt)),
                 tool_calls: vec![],
             },
             finish_reason: None,
@@ -1160,20 +1183,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
         .await;
 
-    let completion = EditPrediction {
+    let prediction = EditPrediction {
         edits,
         edit_preview,
         buffer: buffer.clone(),
         snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
         id: EditPredictionId("the-id".into()),
-        inputs: EditPredictionInputs {
+        inputs: ZetaPromptInput {
             events: Default::default(),
-            included_files: Default::default(),
-            cursor_point: cloud_llm_client::predict_edits_v3::Point {
-                line: Line(0),
-                column: 0,
-            },
+            related_files: Default::default(),
             cursor_path: Path::new("").into(),
+            cursor_excerpt: "".into(),
+            editable_range_in_excerpt: 0..0,
+            cursor_offset_in_excerpt: 0,
         },
         buffer_snapshotted_at: Instant::now(),
         response_received_at: Instant::now(),
@@ -1182,7 +1204,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
     cx.update(|cx| {
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1192,7 +1214,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1202,7 +1224,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         buffer.update(cx, |buffer, cx| buffer.undo(cx));
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1212,7 +1234,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1222,7 +1244,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1232,7 +1254,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1242,7 +1264,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1252,7 +1274,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
         assert_eq!(
             from_completion_edits(
-                &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+                &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
                 &buffer,
                 cx
             ),
@@ -1260,7 +1282,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
         );
 
         buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
-        assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
+        assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
     })
 }
 

crates/edit_prediction/src/mercury.rs 🔗

@@ -1,20 +1,17 @@
 use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
 use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
 use futures::{AsyncReadExt as _, FutureExt, future::Shared};
 use gpui::{
-    App, AppContext as _, Entity, Task,
+    App, AppContext as _, Task,
     http_client::{self, AsyncBody, Method},
 };
-use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
-use project::{Project, ProjectPath};
-use std::{
-    collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
-};
+use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
+use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
+use zeta_prompt::ZetaPromptInput;
 
 use crate::{
-    EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
+    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+    EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
     prediction::EditPredictionResult,
 };
 
@@ -38,16 +35,17 @@ impl Mercury {
         store_api_token_in_keychain(api_token, cx)
     }
 
-    pub fn request_prediction(
+    pub(crate) fn request_prediction(
         &self,
-        _project: &Entity<Project>,
-        active_buffer: &Entity<Buffer>,
-        snapshot: BufferSnapshot,
-        position: language::Anchor,
-        events: Vec<Arc<Event>>,
-        _recent_paths: &VecDeque<ProjectPath>,
-        related_files: Vec<RelatedFile>,
-        _diagnostic_search_range: Range<Point>,
+        EditPredictionModelInput {
+            buffer,
+            snapshot,
+            position,
+            events,
+            related_files,
+            debug_tx,
+            ..
+        }: EditPredictionModelInput,
         cx: &mut App,
     ) -> Task<Result<Option<EditPredictionResult>>> {
         let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
         let http_client = cx.http_client();
         let cursor_point = position.to_point(&snapshot);
         let buffer_snapshotted_at = Instant::now();
+        let active_buffer = buffer.clone();
 
         let result = cx.background_spawn(async move {
             let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
                     MAX_REWRITE_TOKENS,
                 );
 
-            let offset_range = editable_range.to_offset(&snapshot);
-            let prompt = build_prompt(
-                &events,
-                &related_files,
-                &snapshot,
-                full_path.as_ref(),
-                cursor_point,
-                editable_range,
-                context_range.clone(),
-            );
-
-            let inputs = EditPredictionInputs {
-                events: events,
-                included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
-                    path: full_path.clone(),
-                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
-                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
-                        start_line: cloud_llm_client::predict_edits_v3::Line(
-                            context_range.start.row,
-                        ),
-                        text: snapshot
-                            .text_for_range(context_range.clone())
-                            .collect::<String>()
-                            .into(),
-                    }],
-                }],
-                cursor_point: cloud_llm_client::predict_edits_v3::Point {
-                    column: cursor_point.column,
-                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
-                },
+            let context_offset_range = context_range.to_offset(&snapshot);
+
+            let editable_offset_range = editable_range.to_offset(&snapshot);
+
+            let inputs = zeta_prompt::ZetaPromptInput {
+                events,
+                related_files,
+                cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
+                    - context_range.start.to_offset(&snapshot),
                 cursor_path: full_path.clone(),
+                cursor_excerpt: snapshot
+                    .text_for_range(context_range)
+                    .collect::<String>()
+                    .into(),
+                editable_range_in_excerpt: (editable_offset_range.start
+                    - context_offset_range.start)
+                    ..(editable_offset_range.end - context_offset_range.start),
             };
 
+            let prompt = build_prompt(&inputs);
+
+            if let Some(debug_tx) = &debug_tx {
+                debug_tx
+                    .unbounded_send(DebugEvent::EditPredictionStarted(
+                        EditPredictionStartedDebugEvent {
+                            buffer: active_buffer.downgrade(),
+                            prompt: Some(prompt.clone()),
+                            position,
+                        },
+                    ))
+                    .ok();
+            }
+
             let request_body = open_ai::Request {
                 model: "mercury-coder".into(),
                 messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
             let id = mem::take(&mut response.id);
             let response_str = text_from_response(response).unwrap_or_default();
 
+            if let Some(debug_tx) = &debug_tx {
+                debug_tx
+                    .unbounded_send(DebugEvent::EditPredictionFinished(
+                        EditPredictionFinishedDebugEvent {
+                            buffer: active_buffer.downgrade(),
+                            model_output: Some(response_str.clone()),
+                            position,
+                        },
+                    ))
+                    .ok();
+            }
+
             let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
             let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
 
@@ -168,15 +179,16 @@ impl Mercury {
 
             if response_str != NO_PREDICTION_OUTPUT {
                 let old_text = snapshot
-                    .text_for_range(offset_range.clone())
+                    .text_for_range(editable_offset_range.clone())
                     .collect::<String>();
                 edits.extend(
                     language::text_diff(&old_text, &response_str)
                         .into_iter()
                         .map(|(range, text)| {
                             (
-                                snapshot.anchor_after(offset_range.start + range.start)
-                                    ..snapshot.anchor_before(offset_range.start + range.end),
+                                snapshot.anchor_after(editable_offset_range.start + range.start)
+                                    ..snapshot
+                                        .anchor_before(editable_offset_range.start + range.end),
                                 text,
                             )
                         }),
@@ -186,8 +198,6 @@ impl Mercury {
             anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
         });
 
-        let buffer = active_buffer.clone();
-
         cx.spawn(async move |cx| {
             let (id, edits, old_snapshot, response_received_at, inputs) =
                 result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
     }
 }
 
-fn build_prompt(
-    events: &[Arc<Event>],
-    related_files: &[RelatedFile],
-    cursor_buffer: &BufferSnapshot,
-    cursor_buffer_path: &Path,
-    cursor_point: Point,
-    editable_range: Range<Point>,
-    context_range: Range<Point>,
-) -> String {
+fn build_prompt(inputs: &ZetaPromptInput) -> String {
     const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
     const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
     const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
         &mut prompt,
         RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
         |prompt| {
-            for related_file in related_files {
+            for related_file in inputs.related_files.iter() {
                 for related_excerpt in &related_file.excerpts {
                     push_delimited(
                         prompt,
                         RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
                         |prompt| {
                             prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
-                            prompt.push_str(related_file.path.path.as_unix_str());
+                            prompt.push_str(related_file.path.to_string_lossy().as_ref());
                             prompt.push('\n');
                             prompt.push_str(&related_excerpt.text.to_string());
                         },
@@ -259,21 +261,22 @@ fn build_prompt(
         CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
         |prompt| {
             prompt.push_str(CURRENT_FILE_PATH_PREFIX);
-            prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
+            prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
             prompt.push('\n');
 
-            let prefix_range = context_range.start..editable_range.start;
-            let suffix_range = editable_range.end..context_range.end;
-
-            prompt.extend(cursor_buffer.text_for_range(prefix_range));
+            prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
             push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
-                let range_before_cursor = editable_range.start..cursor_point;
-                let range_after_cursor = cursor_point..editable_range.end;
-                prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
+                prompt.push_str(
+                    &inputs.cursor_excerpt
+                        [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
+                );
                 prompt.push_str(CURSOR_TAG);
-                prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
+                prompt.push_str(
+                    &inputs.cursor_excerpt
+                        [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
+                );
             });
-            prompt.extend(cursor_buffer.text_for_range(suffix_range));
+            prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
         },
     );
 
@@ -281,8 +284,8 @@ fn build_prompt(
         &mut prompt,
         EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
         |prompt| {
-            for event in events {
-                writeln!(prompt, "{event}").unwrap();
+            for event in inputs.events.iter() {
+                zeta_prompt::write_event(prompt, &event);
             }
         },
     );

crates/edit_prediction/src/prediction.rs 🔗

@@ -1,6 +1,5 @@
 use std::{
     ops::Range,
-    path::Path,
     sync::Arc,
     time::{Duration, Instant},
 };
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
 use edit_prediction_types::interpolate_edits;
 use gpui::{AsyncApp, Entity, SharedString};
 use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
-use serde::Serialize;
+use zeta_prompt::ZetaPromptInput;
 
 #[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
 pub struct EditPredictionId(pub SharedString);
@@ -40,7 +39,7 @@ impl EditPredictionResult {
         edits: Arc<[(Range<Anchor>, Arc<str>)]>,
         buffer_snapshotted_at: Instant,
         response_received_at: Instant,
-        inputs: EditPredictionInputs,
+        inputs: ZetaPromptInput,
         cx: &mut AsyncApp,
     ) -> Self {
         if edits.is_empty() {
@@ -94,15 +93,7 @@ pub struct EditPrediction {
     pub buffer: Entity<Buffer>,
     pub buffer_snapshotted_at: Instant,
     pub response_received_at: Instant,
-    pub inputs: EditPredictionInputs,
-}
-
-#[derive(Debug, Clone, Serialize)]
-pub struct EditPredictionInputs {
-    pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
-    pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
-    pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
-    pub cursor_path: Arc<Path>,
+    pub inputs: zeta_prompt::ZetaPromptInput,
 }
 
 impl EditPrediction {
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
 
 #[cfg(test)]
 mod tests {
+    use std::path::Path;
+
     use super::*;
     use gpui::{App, Entity, TestAppContext, prelude::*};
     use language::{Buffer, ToOffset as _};
+    use zeta_prompt::ZetaPromptInput;
 
     #[gpui::test]
     async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -154,14 +148,13 @@ mod tests {
             snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
             buffer: buffer.clone(),
             edit_preview,
-            inputs: EditPredictionInputs {
+            inputs: ZetaPromptInput {
                 events: vec![],
-                included_files: vec![],
-                cursor_point: cloud_llm_client::predict_edits_v3::Point {
-                    line: cloud_llm_client::predict_edits_v3::Line(0),
-                    column: 0,
-                },
+                related_files: vec![].into(),
                 cursor_path: Path::new("path.txt").into(),
+                cursor_offset_in_excerpt: 0,
+                cursor_excerpt: "".into(),
+                editable_range_in_excerpt: 0..0,
             },
             buffer_snapshotted_at: Instant::now(),
             response_received_at: Instant::now(),

crates/edit_prediction/src/sweep_ai.rs 🔗

@@ -1,26 +1,21 @@
 use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
 use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
 use futures::{AsyncReadExt as _, FutureExt, future::Shared};
 use gpui::{
-    App, AppContext as _, Entity, Task,
+    App, AppContext as _, Task,
     http_client::{self, AsyncBody, Method},
 };
-use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
+use language::{Point, ToOffset as _};
 use lsp::DiagnosticSeverity;
-use project::{Project, ProjectPath};
 use serde::{Deserialize, Serialize};
 use std::{
-    collections::VecDeque,
     fmt::{self, Write as _},
-    ops::Range,
     path::Path,
     sync::Arc,
     time::Instant,
 };
 
-use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
+use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
 
 const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
 
@@ -44,40 +39,34 @@ impl SweepAi {
 
     pub fn request_prediction_with_sweep(
         &self,
-        project: &Entity<Project>,
-        active_buffer: &Entity<Buffer>,
-        snapshot: BufferSnapshot,
-        position: language::Anchor,
-        events: Vec<Arc<Event>>,
-        recent_paths: &VecDeque<ProjectPath>,
-        related_files: Vec<RelatedFile>,
-        diagnostic_search_range: Range<Point>,
+        inputs: EditPredictionModelInput,
         cx: &mut App,
     ) -> Task<Result<Option<EditPredictionResult>>> {
         let debug_info = self.debug_info.clone();
         let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
             return Task::ready(Ok(None));
         };
-        let full_path: Arc<Path> = snapshot
+        let full_path: Arc<Path> = inputs
+            .snapshot
             .file()
             .map(|file| file.full_path(cx))
             .unwrap_or_else(|| "untitled".into())
             .into();
 
-        let project_file = project::File::from_dyn(snapshot.file());
+        let project_file = project::File::from_dyn(inputs.snapshot.file());
         let repo_name = project_file
             .map(|file| file.worktree.read(cx).root_name_str())
             .unwrap_or("untitled")
             .into();
-        let offset = position.to_offset(&snapshot);
+        let offset = inputs.position.to_offset(&inputs.snapshot);
 
-        let recent_buffers = recent_paths.iter().cloned();
+        let recent_buffers = inputs.recent_paths.iter().cloned();
         let http_client = cx.http_client();
 
         let recent_buffer_snapshots = recent_buffers
             .filter_map(|project_path| {
-                let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
-                if active_buffer == &buffer {
+                let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
+                if inputs.buffer == buffer {
                     None
                 } else {
                     Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
             .take(3)
             .collect::<Vec<_>>();
 
-        let cursor_point = position.to_point(&snapshot);
         let buffer_snapshotted_at = Instant::now();
 
         let result = cx.background_spawn(async move {
-            let text = snapshot.text();
+            let text = inputs.snapshot.text();
 
             let mut recent_changes = String::new();
-            for event in &events {
+            for event in &inputs.events {
                 write_event(event.as_ref(), &mut recent_changes).unwrap();
             }
 
@@ -122,20 +110,23 @@ impl SweepAi {
                 })
                 .collect::<Vec<_>>();
 
-            let retrieval_chunks = related_files
+            let retrieval_chunks = inputs
+                .related_files
                 .iter()
                 .flat_map(|related_file| {
                     related_file.excerpts.iter().map(|excerpt| FileChunk {
-                        file_path: related_file.path.path.as_unix_str().to_string(),
-                        start_line: excerpt.point_range.start.row as usize,
-                        end_line: excerpt.point_range.end.row as usize,
+                        file_path: related_file.path.to_string_lossy().to_string(),
+                        start_line: excerpt.row_range.start as usize,
+                        end_line: excerpt.row_range.end as usize,
                         content: excerpt.text.to_string(),
                         timestamp: None,
                     })
                 })
                 .collect();
 
-            let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
+            let diagnostic_entries = inputs
+                .snapshot
+                .diagnostics_in_range(inputs.diagnostic_search_range, false);
             let mut diagnostic_content = String::new();
             let mut diagnostic_count = 0;
 
@@ -195,21 +186,14 @@ impl SweepAi {
             serde_json::to_writer(writer, &request_body)?;
             let body: AsyncBody = buf.into();
 
-            let inputs = EditPredictionInputs {
-                events,
-                included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
-                    path: full_path.clone(),
-                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
-                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
-                        start_line: cloud_llm_client::predict_edits_v3::Line(0),
-                        text: request_body.file_contents.into(),
-                    }],
-                }],
-                cursor_point: cloud_llm_client::predict_edits_v3::Point {
-                    column: cursor_point.column,
-                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
-                },
+            let ep_inputs = zeta_prompt::ZetaPromptInput {
+                events: inputs.events,
+                related_files: inputs.related_files.clone(),
                 cursor_path: full_path.clone(),
+                cursor_excerpt: request_body.file_contents.into(),
+                // we actually don't know
+                editable_range_in_excerpt: 0..inputs.snapshot.len(),
+                cursor_offset_in_excerpt: request_body.cursor_position,
             };
 
             let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {
 
             let response: AutocompleteResponse = serde_json::from_slice(&body)?;
 
-            let old_text = snapshot
+            let old_text = inputs
+                .snapshot
                 .text_for_range(response.start_index..response.end_index)
                 .collect::<String>();
             let edits = language::text_diff(&old_text, &response.completion)
                 .into_iter()
                 .map(|(range, text)| {
                     (
-                        snapshot.anchor_after(response.start_index + range.start)
-                            ..snapshot.anchor_before(response.start_index + range.end),
+                        inputs
+                            .snapshot
+                            .anchor_after(response.start_index + range.start)
+                            ..inputs
+                                .snapshot
+                                .anchor_before(response.start_index + range.end),
                         text,
                     )
                 })
@@ -254,13 +243,13 @@ impl SweepAi {
             anyhow::Ok((
                 response.autocomplete_id,
                 edits,
-                snapshot,
+                inputs.snapshot,
                 response_received_at,
-                inputs,
+                ep_inputs,
             ))
         });
 
-        let buffer = active_buffer.clone();
+        let buffer = inputs.buffer.clone();
 
         cx.spawn(async move |cx| {
             let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
     pub finish_reason: Option<String>,
 }
 
-fn write_event(
-    event: &cloud_llm_client::predict_edits_v3::Event,
-    f: &mut impl fmt::Write,
-) -> fmt::Result {
+fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
     match event {
-        cloud_llm_client::predict_edits_v3::Event::BufferChange {
+        zeta_prompt::Event::BufferChange {
             old_path,
             path,
             diff,

crates/edit_prediction/src/udiff.rs 🔗

@@ -14,68 +14,18 @@ use anyhow::anyhow;
 use collections::HashMap;
 use gpui::AsyncApp;
 use gpui::Entity;
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
+use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
 use project::Project;
 
-pub async fn parse_diff<'a>(
-    diff_str: &'a str,
-    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
-    let mut diff = DiffParser::new(diff_str);
-    let mut edited_buffer = None;
-    let mut edits = Vec::new();
-
-    while let Some(event) = diff.next()? {
-        match event {
-            DiffEvent::Hunk {
-                path: file_path,
-                hunk,
-            } => {
-                let (buffer, ranges) = match edited_buffer {
-                    None => {
-                        edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
-                        edited_buffer
-                            .as_ref()
-                            .context("Model tried to edit a file that wasn't included")?
-                    }
-                    Some(ref current) => current,
-                };
-
-                edits.extend(
-                    resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
-                        .with_context(|| format!("Diff:\n{diff_str}"))?,
-                );
-            }
-            DiffEvent::FileEnd { renamed_to } => {
-                let (buffer, _) = edited_buffer
-                    .take()
-                    .context("Got a FileEnd event before an Hunk event")?;
-
-                if renamed_to.is_some() {
-                    anyhow::bail!("edit predictions cannot rename files");
-                }
-
-                if diff.next()?.is_some() {
-                    anyhow::bail!("Edited more than one file");
-                }
-
-                return Ok((buffer, edits));
-            }
-        }
-    }
-
-    Err(anyhow::anyhow!("No EOF"))
-}
-
-#[derive(Debug)]
-pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
+#[derive(Clone, Debug)]
+pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
 
 #[must_use]
-pub async fn apply_diff<'a>(
-    diff_str: &'a str,
+pub async fn apply_diff(
+    diff_str: &str,
     project: &Entity<Project>,
     cx: &mut AsyncApp,
-) -> Result<OpenedBuffers<'a>> {
+) -> Result<OpenedBuffers> {
     let mut included_files = HashMap::default();
 
     for line in diff_str.lines() {
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
                 })??
                 .await?;
 
-            included_files.insert(path, buffer);
+            included_files.insert(path.to_string(), buffer);
         }
     }
 
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
                 let (buffer, ranges) = match current_file {
                     None => {
                         let buffer = included_files
-                            .get_mut(&file_path)
+                            .get_mut(file_path.as_ref())
                             .expect("Opened all files in diff");
 
                         current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
     Ok(OpenedBuffers(included_files))
 }
 
+pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
+    let mut diff = DiffParser::new(diff_str);
+
+    let mut text = text.to_string();
+
+    while let Some(event) = diff.next()? {
+        match event {
+            DiffEvent::Hunk { hunk, .. } => {
+                let hunk_offset = text
+                    .find(&hunk.context)
+                    .ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
+                for edit in hunk.edits.iter().rev() {
+                    let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
+                    text.replace_range(range, &edit.text);
+                }
+            }
+            DiffEvent::FileEnd { .. } => {}
+        }
+    }
+
+    Ok(text)
+}
+
 struct PatchFile<'a> {
     old_path: Cow<'a, str>,
     new_path: Cow<'a, str>,
@@ -492,7 +465,6 @@ mod tests {
     use super::*;
     use gpui::TestAppContext;
     use indoc::indoc;
-    use language::Point;
     use pretty_assertions::assert_eq;
     use project::{FakeFs, Project};
     use serde_json::json;
@@ -817,137 +789,6 @@ mod tests {
         });
     }
 
-    #[gpui::test]
-    async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
-        let fs = init_test(cx);
-
-        let buffer_1_text = indoc! {r#"
-            one
-            two
-            three
-            four
-            five
-            one
-            two
-            three
-            four
-            five
-        "# };
-
-        fs.insert_tree(
-            path!("/root"),
-            json!({
-                "file1": buffer_1_text,
-            }),
-        )
-        .await;
-
-        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
-        let buffer = project
-            .update(cx, |project, cx| {
-                project.open_local_buffer(path!("/root/file1"), cx)
-            })
-            .await
-            .unwrap();
-        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
-        let diff = indoc! {r#"
-            --- a/root/file1
-            +++ b/root/file1
-             one
-             two
-            -three
-            +3
-             four
-             five
-        "#};
-
-        let final_text = indoc! {r#"
-            one
-            two
-            three
-            four
-            five
-            one
-            two
-            3
-            four
-            five
-        "#};
-
-        apply_diff(diff, &project, &mut cx.to_async())
-            .await
-            .expect_err("Non-unique edits should fail");
-
-        let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
-            ..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
-
-        let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
-            .await
-            .unwrap();
-
-        assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
-        buffer.update(cx, |buffer, cx| {
-            buffer.edit(edits, None, cx);
-            assert_eq!(buffer.text(), final_text);
-        });
-    }
-
-    #[gpui::test]
-    async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
-        let fs = init_test(cx);
-
-        let buffer_1_text = indoc! {r#"
-            one two three four
-            five six seven eight
-            nine ten eleven twelve
-        "# };
-
-        fs.insert_tree(
-            path!("/root"),
-            json!({
-                "file1": buffer_1_text,
-            }),
-        )
-        .await;
-
-        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
-        let buffer = project
-            .update(cx, |project, cx| {
-                project.open_local_buffer(path!("/root/file1"), cx)
-            })
-            .await
-            .unwrap();
-        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
-        let diff = indoc! {r#"
-            --- a/root/file1
-            +++ b/root/file1
-             one two three four
-            -five six seven eight
-            +five SIX seven eight!
-             nine ten eleven twelve
-        "#};
-
-        let (buffer, edits) = parse_diff(diff, |_path| {
-            Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
-        })
-        .await
-        .unwrap();
-
-        let edits = edits
-            .into_iter()
-            .map(|(range, text)| (range.to_point(&buffer), text))
-            .collect::<Vec<_>>();
-        assert_eq!(
-            edits,
-            &[
-                (Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
-                (Point::new(1, 20)..Point::new(1, 20), "!".into())
-            ]
-        );
-    }
-
     #[gpui::test]
     async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
         let fs = init_test(cx);

crates/edit_prediction/src/xml_edits.rs 🔗

@@ -1,637 +0,0 @@
-use anyhow::{Context as _, Result};
-use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
-use std::{cmp, ops::Range, path::Path, sync::Arc};
-
-const EDITS_TAG_NAME: &'static str = "edits";
-const OLD_TEXT_TAG_NAME: &'static str = "old_text";
-const NEW_TEXT_TAG_NAME: &'static str = "new_text";
-const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
-
-pub async fn parse_xml_edits<'a>(
-    input: &'a str,
-    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
-    parse_xml_edits_inner(input, get_buffer)
-        .await
-        .with_context(|| format!("Failed to parse XML edits:\n{input}"))
-}
-
-async fn parse_xml_edits_inner<'a>(
-    input: &'a str,
-    get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
-    let xml_edits = extract_xml_replacements(input)?;
-
-    let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
-        .with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
-
-    let mut all_edits = vec![];
-    for (old_text, new_text) in xml_edits.replacements {
-        let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
-        let matched_old_text = buffer
-            .text_for_range(match_range.clone())
-            .collect::<String>();
-        let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
-        all_edits.extend(
-            edits_within_hunk
-                .into_iter()
-                .map(move |(inner_range, inner_text)| {
-                    (
-                        buffer.anchor_after(match_range.start + inner_range.start)
-                            ..buffer.anchor_before(match_range.start + inner_range.end),
-                        inner_text,
-                    )
-                }),
-        );
-    }
-
-    Ok((buffer, all_edits))
-}
-
-fn fuzzy_match_in_ranges(
-    old_text: &str,
-    buffer: &BufferSnapshot,
-    context_ranges: &[Range<Anchor>],
-) -> Result<Range<usize>> {
-    let mut state = FuzzyMatcher::new(buffer, old_text);
-    let mut best_match = None;
-    let mut tie_match_range = None;
-
-    for range in context_ranges {
-        let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
-        match (best_match_cost, state.match_range(range.to_offset(buffer))) {
-            (Some(lowest_cost), Some((new_cost, new_range))) => {
-                if new_cost == lowest_cost {
-                    tie_match_range = Some(new_range);
-                } else if new_cost < lowest_cost {
-                    tie_match_range.take();
-                    best_match = Some((new_cost, new_range));
-                }
-            }
-            (None, Some(new_match)) => {
-                best_match = Some(new_match);
-            }
-            (None, None) | (Some(_), None) => {}
-        };
-    }
-
-    if let Some((_, best_match_range)) = best_match {
-        if let Some(tie_match_range) = tie_match_range {
-            anyhow::bail!(
-                "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
-                best_match_range.clone(),
-                buffer.text_for_range(best_match_range).collect::<String>(),
-                tie_match_range.clone(),
-                buffer.text_for_range(tie_match_range).collect::<String>()
-            );
-        }
-        return Ok(best_match_range);
-    }
-
-    anyhow::bail!(
-        "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
-        old_text,
-        context_ranges
-            .iter()
-            .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
-            .collect::<Vec<String>>()
-            .join("```\n```")
-    );
-}
-
-#[derive(Debug)]
-struct XmlEdits<'a> {
-    file_path: &'a str,
-    /// Vec of (old_text, new_text) pairs
-    replacements: Vec<(&'a str, &'a str)>,
-}
-
-fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
-    let mut cursor = 0;
-
-    let (edits_body_start, edits_attrs) =
-        find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
-
-    let file_path = edits_attrs
-        .trim_start()
-        .strip_prefix("path")
-        .context("no path attribute on edits tag")?
-        .trim_end()
-        .strip_prefix('=')
-        .context("no value for path attribute")?
-        .trim()
-        .trim_start_matches('"')
-        .trim_end_matches('"');
-
-    cursor = edits_body_start;
-    let mut edits_list = Vec::new();
-
-    while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
-        let old_body_end = find_tag_close(input, &mut cursor)?;
-        let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
-
-        let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
-            .context("no new_text tag following old_text")?;
-        let new_body_end = find_tag_close(input, &mut cursor)?;
-        let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
-
-        edits_list.push((old_text, new_text));
-    }
-
-    Ok(XmlEdits {
-        file_path,
-        replacements: edits_list,
-    })
-}
-
-/// Trims a single leading and trailing newline
-fn trim_surrounding_newlines(input: &str) -> &str {
-    let start = input.strip_prefix('\n').unwrap_or(input);
-    let end = start.strip_suffix('\n').unwrap_or(start);
-    end
-}
-
-fn find_tag_open<'a>(
-    input: &'a str,
-    cursor: &mut usize,
-    expected_tag: &str,
-) -> Result<Option<(usize, &'a str)>> {
-    let mut search_pos = *cursor;
-
-    while search_pos < input.len() {
-        let Some(tag_start) = input[search_pos..].find("<") else {
-            break;
-        };
-        let tag_start = search_pos + tag_start;
-        if !input[tag_start + 1..].starts_with(expected_tag) {
-            search_pos = search_pos + tag_start + 1;
-            continue;
-        };
-
-        let after_tag_name = tag_start + expected_tag.len() + 1;
-        let close_bracket = input[after_tag_name..]
-            .find('>')
-            .with_context(|| format!("missing > after <{}", expected_tag))?;
-        let attrs_end = after_tag_name + close_bracket;
-        let body_start = attrs_end + 1;
-
-        let attributes = input[after_tag_name..attrs_end].trim();
-        *cursor = body_start;
-
-        return Ok(Some((body_start, attributes)));
-    }
-
-    Ok(None)
-}
-
-fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
-    let mut depth = 1;
-    let mut search_pos = *cursor;
-
-    while search_pos < input.len() && depth > 0 {
-        let Some(bracket_offset) = input[search_pos..].find('<') else {
-            break;
-        };
-        let bracket_pos = search_pos + bracket_offset;
-
-        if input[bracket_pos..].starts_with("</")
-            && let Some(close_end) = input[bracket_pos + 2..].find('>')
-        {
-            let close_start = bracket_pos + 2;
-            let tag_name = input[close_start..close_start + close_end].trim();
-
-            if XML_TAGS.contains(&tag_name) {
-                depth -= 1;
-                if depth == 0 {
-                    *cursor = close_start + close_end + 1;
-                    return Ok(bracket_pos);
-                }
-            }
-            search_pos = close_start + close_end + 1;
-            continue;
-        } else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
-            let close_bracket_pos = bracket_pos + close_bracket_offset;
-            let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
-            if XML_TAGS.contains(&tag_name) {
-                depth += 1;
-            }
-        }
-
-        search_pos = bracket_pos + 1;
-    }
-
-    anyhow::bail!("no closing tag found")
-}
-
-const REPLACEMENT_COST: u32 = 1;
-const INSERTION_COST: u32 = 3;
-const DELETION_COST: u32 = 10;
-
-/// A fuzzy matcher that can process text chunks incrementally
-/// and return the best match found so far at each step.
-struct FuzzyMatcher<'a> {
-    snapshot: &'a BufferSnapshot,
-    query_lines: Vec<&'a str>,
-    matrix: SearchMatrix,
-}
-
-impl<'a> FuzzyMatcher<'a> {
-    fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
-        let query_lines = old_text.lines().collect();
-        Self {
-            snapshot,
-            query_lines,
-            matrix: SearchMatrix::new(0),
-        }
-    }
-
-    fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
-        let point_range = range.to_point(&self.snapshot);
-        let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
-
-        self.matrix
-            .reset(self.query_lines.len() + 1, buffer_line_count + 1);
-        let query_line_count = self.query_lines.len();
-
-        for row in 0..query_line_count {
-            let query_line = self.query_lines[row].trim();
-            let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
-
-            self.matrix.set(
-                row + 1,
-                0,
-                SearchState::new(leading_deletion_cost, SearchDirection::Up),
-            );
-
-            let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
-
-            let mut col = 0;
-            while let Some(buffer_line) = buffer_lines.next() {
-                let buffer_line = buffer_line.trim();
-                let up = SearchState::new(
-                    self.matrix
-                        .get(row, col + 1)
-                        .cost
-                        .saturating_add(DELETION_COST),
-                    SearchDirection::Up,
-                );
-                let left = SearchState::new(
-                    self.matrix
-                        .get(row + 1, col)
-                        .cost
-                        .saturating_add(INSERTION_COST),
-                    SearchDirection::Left,
-                );
-                let diagonal = SearchState::new(
-                    if query_line == buffer_line {
-                        self.matrix.get(row, col).cost
-                    } else if fuzzy_eq(query_line, buffer_line) {
-                        self.matrix.get(row, col).cost + REPLACEMENT_COST
-                    } else {
-                        self.matrix
-                            .get(row, col)
-                            .cost
-                            .saturating_add(DELETION_COST + INSERTION_COST)
-                    },
-                    SearchDirection::Diagonal,
-                );
-                self.matrix
-                    .set(row + 1, col + 1, up.min(left).min(diagonal));
-                col += 1;
-            }
-        }
-
-        // Find all matches with the best cost
-        let mut best_cost = u32::MAX;
-        let mut matches_with_best_cost = Vec::new();
-
-        for col in 1..=buffer_line_count {
-            let cost = self.matrix.get(query_line_count, col).cost;
-            if cost < best_cost {
-                best_cost = cost;
-                matches_with_best_cost.clear();
-                matches_with_best_cost.push(col as u32);
-            } else if cost == best_cost {
-                matches_with_best_cost.push(col as u32);
-            }
-        }
-
-        // Find ranges for the matches
-        for &match_end_col in &matches_with_best_cost {
-            let mut matched_lines = 0;
-            let mut query_row = query_line_count;
-            let mut match_start_col = match_end_col;
-            while query_row > 0 && match_start_col > 0 {
-                let current = self.matrix.get(query_row, match_start_col as usize);
-                match current.direction {
-                    SearchDirection::Diagonal => {
-                        query_row -= 1;
-                        match_start_col -= 1;
-                        matched_lines += 1;
-                    }
-                    SearchDirection::Up => {
-                        query_row -= 1;
-                    }
-                    SearchDirection::Left => {
-                        match_start_col -= 1;
-                    }
-                }
-            }
-
-            let buffer_row_start = match_start_col + point_range.start.row;
-            let buffer_row_end = match_end_col + point_range.start.row;
-
-            let matched_buffer_row_count = buffer_row_end - buffer_row_start;
-            let matched_ratio = matched_lines as f32
-                / (matched_buffer_row_count as f32).max(query_line_count as f32);
-            if matched_ratio >= 0.8 {
-                let buffer_start_ix = self
-                    .snapshot
-                    .point_to_offset(Point::new(buffer_row_start, 0));
-                let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
-                    buffer_row_end - 1,
-                    self.snapshot.line_len(buffer_row_end - 1),
-                ));
-                return Some((best_cost, buffer_start_ix..buffer_end_ix));
-            }
-        }
-
-        None
-    }
-}
-
-fn fuzzy_eq(left: &str, right: &str) -> bool {
-    const THRESHOLD: f64 = 0.8;
-
-    let min_levenshtein = left.len().abs_diff(right.len());
-    let min_normalized_levenshtein =
-        1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
-    if min_normalized_levenshtein < THRESHOLD {
-        return false;
-    }
-
-    strsim::normalized_levenshtein(left, right) >= THRESHOLD
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
-enum SearchDirection {
-    Up,
-    Left,
-    Diagonal,
-}
-
-#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct SearchState {
-    cost: u32,
-    direction: SearchDirection,
-}
-
-impl SearchState {
-    fn new(cost: u32, direction: SearchDirection) -> Self {
-        Self { cost, direction }
-    }
-}
-
-struct SearchMatrix {
-    cols: usize,
-    rows: usize,
-    data: Vec<SearchState>,
-}
-
-impl SearchMatrix {
-    fn new(cols: usize) -> Self {
-        SearchMatrix {
-            cols,
-            rows: 0,
-            data: Vec::new(),
-        }
-    }
-
-    fn reset(&mut self, rows: usize, cols: usize) {
-        self.rows = rows;
-        self.cols = cols;
-        self.data
-            .fill(SearchState::new(0, SearchDirection::Diagonal));
-        self.data.resize(
-            self.rows * self.cols,
-            SearchState::new(0, SearchDirection::Diagonal),
-        );
-    }
-
-    fn get(&self, row: usize, col: usize) -> SearchState {
-        debug_assert!(row < self.rows);
-        debug_assert!(col < self.cols);
-        self.data[row * self.cols + col]
-    }
-
-    fn set(&mut self, row: usize, col: usize, state: SearchState) {
-        debug_assert!(row < self.rows && col < self.cols);
-        self.data[row * self.cols + col] = state;
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use gpui::TestAppContext;
-    use indoc::indoc;
-    use language::Point;
-    use project::{FakeFs, Project};
-    use serde_json::json;
-    use settings::SettingsStore;
-    use util::path;
-
-    #[test]
-    fn test_extract_xml_edits() {
-        let input = indoc! {r#"
-            <edits path="test.rs">
-            <old_text>
-            old content
-            </old_text>
-            <new_text>
-            new content
-            </new_text>
-            </edits>
-        "#};
-
-        let result = extract_xml_replacements(input).unwrap();
-        assert_eq!(result.file_path, "test.rs");
-        assert_eq!(result.replacements.len(), 1);
-        assert_eq!(result.replacements[0].0, "old content");
-        assert_eq!(result.replacements[0].1, "new content");
-    }
-
-    #[test]
-    fn test_extract_xml_edits_with_wrong_closing_tags() {
-        let input = indoc! {r#"
-            <edits path="test.rs">
-            <old_text>
-            old content
-            </new_text>
-            <new_text>
-            new content
-            </old_text>
-            </ edits >
-        "#};
-
-        let result = extract_xml_replacements(input).unwrap();
-        assert_eq!(result.file_path, "test.rs");
-        assert_eq!(result.replacements.len(), 1);
-        assert_eq!(result.replacements[0].0, "old content");
-        assert_eq!(result.replacements[0].1, "new content");
-    }
-
-    #[test]
-    fn test_extract_xml_edits_with_xml_like_content() {
-        let input = indoc! {r#"
-            <edits path="component.tsx">
-            <old_text>
-            <foo><bar></bar></foo>
-            </old_text>
-            <new_text>
-            <foo><bar><baz></baz></bar></foo>
-            </new_text>
-            </edits>
-        "#};
-
-        let result = extract_xml_replacements(input).unwrap();
-        assert_eq!(result.file_path, "component.tsx");
-        assert_eq!(result.replacements.len(), 1);
-        assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
-        assert_eq!(
-            result.replacements[0].1,
-            "<foo><bar><baz></baz></bar></foo>"
-        );
-    }
-
-    #[test]
-    fn test_extract_xml_edits_with_conflicting_content() {
-        let input = indoc! {r#"
-            <edits path="component.tsx">
-            <old_text>
-            <new_text></new_text>
-            </old_text>
-            <new_text>
-            <old_text></old_text>
-            </new_text>
-            </edits>
-        "#};
-
-        let result = extract_xml_replacements(input).unwrap();
-        assert_eq!(result.file_path, "component.tsx");
-        assert_eq!(result.replacements.len(), 1);
-        assert_eq!(result.replacements[0].0, "<new_text></new_text>");
-        assert_eq!(result.replacements[0].1, "<old_text></old_text>");
-    }
-
-    #[test]
-    fn test_extract_xml_edits_multiple_pairs() {
-        let input = indoc! {r#"
-            Some reasoning before edits. Lots of thinking going on here
-
-            <edits path="test.rs">
-            <old_text>
-            first old
-            </old_text>
-            <new_text>
-            first new
-            </new_text>
-            <old_text>
-            second old
-            </edits>
-            <new_text>
-            second new
-            </old_text>
-            </edits>
-        "#};
-
-        let result = extract_xml_replacements(input).unwrap();
-        assert_eq!(result.file_path, "test.rs");
-        assert_eq!(result.replacements.len(), 2);
-        assert_eq!(result.replacements[0].0, "first old");
-        assert_eq!(result.replacements[0].1, "first new");
-        assert_eq!(result.replacements[1].0, "second old");
-        assert_eq!(result.replacements[1].1, "second new");
-    }
-
-    #[test]
-    fn test_extract_xml_edits_unexpected_eof() {
-        let input = indoc! {r#"
-            <edits path="test.rs">
-            <old_text>
-            first old
-            </
-        "#};
-
-        extract_xml_replacements(input).expect_err("Unexpected end of file");
-    }
-
-    #[gpui::test]
-    async fn test_parse_xml_edits(cx: &mut TestAppContext) {
-        let fs = init_test(cx);
-
-        let buffer_1_text = indoc! {r#"
-            one two three four
-            five six seven eight
-            nine ten eleven twelve
-            thirteen fourteen fifteen
-            sixteen seventeen eighteen
-        "#};
-
-        fs.insert_tree(
-            path!("/root"),
-            json!({
-                "file1": buffer_1_text,
-            }),
-        )
-        .await;
-
-        let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
-        let buffer = project
-            .update(cx, |project, cx| {
-                project.open_local_buffer(path!("/root/file1"), cx)
-            })
-            .await
-            .unwrap();
-        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
-        let edits = indoc! {r#"
-            <edits path="root/file1">
-            <old_text>
-            nine ten eleven twelve
-            </old_text>
-            <new_text>
-            nine TEN eleven twelve!
-            </new_text>
-            </edits>
-        "#};
-
-        let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
-        let (buffer, edits) = parse_xml_edits(edits, |_path| {
-            Some((&buffer_snapshot, included_ranges.as_slice()))
-        })
-        .await
-        .unwrap();
-
-        let edits = edits
-            .into_iter()
-            .map(|(range, text)| (range.to_point(&buffer), text))
-            .collect::<Vec<_>>();
-        assert_eq!(
-            edits,
-            &[
-                (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
-                (Point::new(2, 22)..Point::new(2, 22), "!".into())
-            ]
-        );
-    }
-
-    fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
-        cx.update(|cx| {
-            let settings_store = SettingsStore::test(cx);
-            cx.set_global(settings_store);
-        });
-
-        FakeFs::new(cx.background_executor.clone())
-    }
-}

crates/edit_prediction/src/zeta1.rs 🔗

@@ -1,22 +1,23 @@
 use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
 
 use crate::{
-    EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
+    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
     cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
-    prediction::{EditPredictionInputs, EditPredictionResult},
+    prediction::EditPredictionResult,
 };
 use anyhow::{Context as _, Result};
 use cloud_llm_client::{
     PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
-    predict_edits_v3::Event,
 };
 use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 use language::{
-    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
+    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 };
 use project::{Project, ProjectPath};
 use release_channel::AppVersion;
 use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+use zeta_prompt::{Event, ZetaPromptInput};
 
 const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 
 pub(crate) fn request_prediction_with_zeta1(
     store: &mut EditPredictionStore,
-    project: &Entity<Project>,
-    buffer: &Entity<Buffer>,
-    snapshot: BufferSnapshot,
-    position: language::Anchor,
-    events: Vec<Arc<Event>>,
-    trigger: PredictEditsRequestTrigger,
+    EditPredictionModelInput {
+        project,
+        buffer,
+        snapshot,
+        position,
+        events,
+        trigger,
+        debug_tx,
+        ..
+    }: EditPredictionModelInput,
     cx: &mut Context<EditPredictionStore>,
 ) -> Task<Result<Option<EditPredictionResult>>> {
-    let buffer = buffer.clone();
     let buffer_snapshotted_at = Instant::now();
     let client = store.client.clone();
     let llm_token = store.llm_token.clone();
     let app_version = AppVersion::global(cx);
 
     let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
-        let can_collect_file = store.can_collect_file(project, file, cx);
+        let can_collect_file = store.can_collect_file(&project, file, cx);
         let git_info = if can_collect_file {
-            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
         } else {
             None
         };
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
         )
         .await;
 
-        let inputs = EditPredictionInputs {
+        let context_start_offset = context_range.start.to_offset(&snapshot);
+        let editable_offset_range = editable_range.to_offset(&snapshot);
+
+        let inputs = ZetaPromptInput {
             events: included_events.into(),
-            included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
-                path: full_path.clone(),
-                max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
-                excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
-                    start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
-                    text: snapshot
-                        .text_for_range(context_range)
-                        .collect::<String>()
-                        .into(),
-                }],
-            }],
-            cursor_point: cloud_llm_client::predict_edits_v3::Point {
-                column: cursor_point.column,
-                line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
-            },
+            related_files: vec![].into(),
             cursor_path: full_path,
+            cursor_excerpt: snapshot
+                .text_for_range(context_range)
+                .collect::<String>()
+                .into(),
+            editable_range_in_excerpt: (editable_range.start - context_start_offset)
+                ..(editable_offset_range.end - context_start_offset),
+            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
         };
 
-        // let response = perform_predict_edits(PerformPredictEditsParams {
-        //     client,
-        //     llm_token,
-        //     app_version,
-        //     body,
-        // })
-        // .await;
+        if let Some(debug_tx) = &debug_tx {
+            debug_tx
+                .unbounded_send(DebugEvent::EditPredictionStarted(
+                    EditPredictionStartedDebugEvent {
+                        buffer: buffer.downgrade(),
+                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
+                        position,
+                    },
+                ))
+                .ok();
+        }
 
         let (response, usage) = match response {
             Ok(response) => response,
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
             .ok();
         }
 
+        if let Some(debug_tx) = &debug_tx {
+            debug_tx
+                .unbounded_send(DebugEvent::EditPredictionFinished(
+                    EditPredictionFinishedDebugEvent {
+                        buffer: buffer.downgrade(),
+                        model_output: Some(response.output_excerpt.clone()),
+                        position,
+                    },
+                ))
+                .ok();
+        }
+
         let edit_prediction = process_completion_response(
             response,
             buffer,
@@ -226,7 +242,7 @@ fn process_completion_response(
     buffer: Entity<Buffer>,
     snapshot: &BufferSnapshot,
     editable_range: Range<usize>,
-    inputs: EditPredictionInputs,
+    inputs: ZetaPromptInput,
     buffer_snapshotted_at: Instant,
     received_response_at: Instant,
     cx: &AsyncApp,

crates/edit_prediction/src/zeta2.rs 🔗

@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
 use crate::open_ai_response::text_from_response;
 use crate::prediction::EditPredictionResult;
 use crate::{
-    DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
-    EditPredictionRequestedDebugEvent, EditPredictionStore,
+    DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
+    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
 };
-use anyhow::{Result, anyhow, bail};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
-use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
-use cloud_zeta2_prompt::CURSOR_MARKER;
-use edit_prediction_context::{EditPredictionExcerpt, Line};
-use edit_prediction_context::{RelatedExcerpt, RelatedFile};
-use futures::channel::oneshot;
-use gpui::{Entity, Task, prelude::*};
-use language::{Anchor, BufferSnapshot};
-use language::{Buffer, Point, ToOffset as _, ToPoint};
-use project::{Project, ProjectItem as _};
+use anyhow::{Result, anyhow};
+use cloud_llm_client::EditPredictionRejectReason;
+use gpui::{Task, prelude::*};
+use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 use release_channel::AppVersion;
-use std::{
-    env,
-    path::Path,
-    sync::Arc,
-    time::{Duration, Instant},
-};
+use std::{path::Path, sync::Arc, time::Instant};
+use zeta_prompt::CURSOR_MARKER;
+use zeta_prompt::format_zeta_prompt;
+
+const MAX_CONTEXT_TOKENS: usize = 150;
+const MAX_REWRITE_TOKENS: usize = 350;
 
 pub fn request_prediction_with_zeta2(
     store: &mut EditPredictionStore,
-    project: &Entity<Project>,
-    active_buffer: &Entity<Buffer>,
-    active_snapshot: BufferSnapshot,
-    position: Anchor,
-    events: Vec<Arc<Event>>,
-    mut included_files: Vec<RelatedFile>,
-    trigger: PredictEditsRequestTrigger,
+    EditPredictionModelInput {
+        buffer,
+        snapshot,
+        position,
+        related_files,
+        events,
+        debug_tx,
+        ..
+    }: EditPredictionModelInput,
     cx: &mut Context<EditPredictionStore>,
 ) -> Task<Result<Option<EditPredictionResult>>> {
-    let options = store.options.clone();
     let buffer_snapshotted_at = Instant::now();
 
-    let Some((excerpt_path, active_project_path)) = active_snapshot
+    let Some(excerpt_path) = snapshot
         .file()
         .map(|file| -> Arc<Path> { file.full_path(cx).into() })
-        .zip(active_buffer.read(cx).project_path(cx))
     else {
         return Task::ready(Err(anyhow!("No file path for excerpt")));
     };
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
     let client = store.client.clone();
     let llm_token = store.llm_token.clone();
     let app_version = AppVersion::global(cx);
-    let debug_tx = store.debug_tx.clone();
-
-    let file = active_buffer.read(cx).file();
-
-    let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
-
-    // TODO data collection
-    let can_collect_data = file
-        .as_ref()
-        .map_or(false, |file| store.can_collect_file(project, file, cx));
 
     #[cfg(feature = "eval-support")]
     let eval_cache = store.eval_cache.clone();
 
     let request_task = cx.background_spawn({
-        let active_buffer = active_buffer.clone();
         async move {
-            let cursor_offset = position.to_offset(&active_snapshot);
-            let cursor_point = cursor_offset.to_point(&active_snapshot);
-
-            let before_retrieval = Instant::now();
-
-            let excerpt_options = options.context;
-
-            let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
-                cursor_point,
-                &active_snapshot,
-                &excerpt_options,
-            ) else {
-                return Ok((None, None));
-            };
-
-            let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
-                ..active_snapshot.anchor_before(excerpt.range.end);
-            let related_excerpt = RelatedExcerpt {
-                anchor_range: excerpt_anchor_range.clone(),
-                point_range: Point::new(excerpt.line_range.start.0, 0)
-                    ..Point::new(excerpt.line_range.end.0, 0),
-                text: active_snapshot.as_rope().slice(excerpt.range),
-            };
-
-            if let Some(buffer_ix) = included_files
-                .iter()
-                .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
-            {
-                let file = &mut included_files[buffer_ix];
-                file.excerpts.push(related_excerpt);
-                file.merge_excerpts();
-                let last_ix = included_files.len() - 1;
-                included_files.swap(buffer_ix, last_ix);
-            } else {
-                let active_file = RelatedFile {
-                    path: active_project_path,
-                    buffer: active_buffer.downgrade(),
-                    excerpts: vec![related_excerpt],
-                    max_row: active_snapshot.max_point().row,
-                };
-                included_files.push(active_file);
-            }
-
-            let included_files = included_files
-                .iter()
-                .map(|related_file| predict_edits_v3::RelatedFile {
-                    path: Arc::from(related_file.path.path.as_std_path()),
-                    max_row: Line(related_file.max_row),
-                    excerpts: related_file
-                        .excerpts
-                        .iter()
-                        .map(|excerpt| predict_edits_v3::Excerpt {
-                            start_line: Line(excerpt.point_range.start.row),
-                            text: excerpt.text.to_string().into(),
-                        })
-                        .collect(),
-                })
-                .collect::<Vec<_>>();
-
-            let cloud_request = predict_edits_v3::PredictEditsRequest {
-                excerpt_path,
-                excerpt: String::new(),
-                excerpt_line_range: Line(0)..Line(0),
-                excerpt_range: 0..0,
-                cursor_point: predict_edits_v3::Point {
-                    line: predict_edits_v3::Line(cursor_point.row),
-                    column: cursor_point.column,
-                },
-                related_files: included_files,
+            let cursor_offset = position.to_offset(&snapshot);
+            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
+                &snapshot,
+                related_files,
                 events,
-                can_collect_data,
-                debug_info: debug_tx.is_some(),
-                prompt_max_bytes: Some(options.max_prompt_bytes),
-                prompt_format: options.prompt_format,
-                excerpt_parent: None,
-                git_info: None,
-                trigger,
-            };
-
-            let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
-
-            let inputs = EditPredictionInputs {
-                included_files: cloud_request.related_files,
-                events: cloud_request.events,
-                cursor_point: cloud_request.cursor_point,
-                cursor_path: cloud_request.excerpt_path,
-            };
-
-            let retrieval_time = Instant::now() - before_retrieval;
+                excerpt_path,
+                cursor_offset,
+            );
 
-            let debug_response_tx = if let Some(debug_tx) = &debug_tx {
-                let (response_tx, response_rx) = oneshot::channel();
+            let prompt = format_zeta_prompt(&prompt_input);
 
+            if let Some(debug_tx) = &debug_tx {
                 debug_tx
-                    .unbounded_send(DebugEvent::EditPredictionRequested(
-                        EditPredictionRequestedDebugEvent {
-                            inputs: inputs.clone(),
-                            retrieval_time,
-                            buffer: active_buffer.downgrade(),
-                            local_prompt: match prompt_result.as_ref() {
-                                Ok(prompt) => Ok(prompt.clone()),
-                                Err(err) => Err(err.to_string()),
-                            },
+                    .unbounded_send(DebugEvent::EditPredictionStarted(
+                        EditPredictionStartedDebugEvent {
+                            buffer: buffer.downgrade(),
+                            prompt: Some(prompt.clone()),
                             position,
-                            response_rx,
                         },
                     ))
                     .ok();
-                Some(response_tx)
-            } else {
-                None
-            };
-
-            if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
-                if let Some(debug_response_tx) = debug_response_tx {
-                    debug_response_tx
-                        .send((Err("Request skipped".to_string()), Duration::ZERO))
-                        .ok();
-                }
-                anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
             }
 
-            let prompt = prompt_result?;
-            let generation_params =
-                cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
             let request = open_ai::Request {
                 model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                 messages: vec![open_ai::RequestMessage::User {
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
                 }],
                 stream: false,
                 max_completion_tokens: None,
-                stop: generation_params.stop.unwrap_or_default(),
-                temperature: generation_params.temperature.or(Some(0.7)),
+                stop: Default::default(),
+                temperature: Default::default(),
                 tool_choice: None,
                 parallel_tool_calls: None,
                 tools: vec![],
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
 
             log::trace!("Sending edit prediction request");
 
-            let before_request = Instant::now();
             let response = EditPredictionStore::send_raw_llm_request(
                 request,
                 client,
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
             )
             .await;
             let received_response_at = Instant::now();
-            let request_time = received_response_at - before_request;
 
             log::trace!("Got edit prediction response");
 
-            if let Some(debug_response_tx) = debug_response_tx {
-                debug_response_tx
-                    .send((
-                        response
-                            .as_ref()
-                            .map_err(|err| err.to_string())
-                            .map(|response| response.0.clone()),
-                        request_time,
-                    ))
-                    .ok();
-            }
-
             let (res, usage) = response?;
             let request_id = EditPredictionId(res.id.clone().into());
             let Some(mut output_text) = text_from_response(res) else {
                 return Ok((Some((request_id, None)), usage));
             };
 
+            if let Some(debug_tx) = &debug_tx {
+                debug_tx
+                    .unbounded_send(DebugEvent::EditPredictionFinished(
+                        EditPredictionFinishedDebugEvent {
+                            buffer: buffer.downgrade(),
+                            position,
+                            model_output: Some(output_text.clone()),
+                        },
+                    ))
+                    .ok();
+            }
+
             if output_text.contains(CURSOR_MARKER) {
                 log::trace!("Stripping out {CURSOR_MARKER} from response");
                 output_text = output_text.replace(CURSOR_MARKER, "");
             }
 
-            let get_buffer_from_context = |path: &Path| {
-                if Some(path) == active_file_full_path.as_deref() {
-                    Some((
-                        &active_snapshot,
-                        std::slice::from_ref(&excerpt_anchor_range),
-                    ))
-                } else {
-                    None
-                }
-            };
-
-            let (_, edits) = match options.prompt_format {
-                PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
-                    if output_text.contains("--- a/\n+++ b/\nNo edits") {
-                        let edits = vec![];
-                        (&active_snapshot, edits)
-                    } else {
-                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
-                    }
-                }
-                PromptFormat::OldTextNewText => {
-                    crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
-                }
-                _ => {
-                    bail!("unsupported prompt format {}", options.prompt_format)
-                }
-            };
+            let old_text = snapshot
+                .text_for_range(editable_offset_range.clone())
+                .collect::<String>();
+            let edits: Vec<_> = language::text_diff(&old_text, &output_text)
+                .into_iter()
+                .map(|(range, text)| {
+                    (
+                        snapshot.anchor_after(editable_offset_range.start + range.start)
+                            ..snapshot.anchor_before(editable_offset_range.start + range.end),
+                        text,
+                    )
+                })
+                .collect();
 
             anyhow::Ok((
                 Some((
                     request_id,
                     Some((
-                        inputs,
-                        active_buffer,
-                        active_snapshot.clone(),
+                        prompt_input,
+                        buffer,
+                        snapshot.clone(),
                         edits,
                         received_response_at,
                     )),
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
         ))
     })
 }
+
+pub fn zeta2_prompt_input(
+    snapshot: &language::BufferSnapshot,
+    related_files: Arc<[zeta_prompt::RelatedFile]>,
+    events: Vec<Arc<zeta_prompt::Event>>,
+    excerpt_path: Arc<Path>,
+    cursor_offset: usize,
+) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
+    let cursor_point = cursor_offset.to_point(snapshot);
+
+    let (editable_range, context_range) =
+        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+            cursor_point,
+            snapshot,
+            MAX_CONTEXT_TOKENS,
+            MAX_REWRITE_TOKENS,
+        );
+
+    let context_start_offset = context_range.start.to_offset(snapshot);
+    let editable_offset_range = editable_range.to_offset(snapshot);
+    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
+    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
+        ..(editable_offset_range.end - context_start_offset);
+
+    let prompt_input = zeta_prompt::ZetaPromptInput {
+        cursor_path: excerpt_path,
+        cursor_excerpt: snapshot
+            .text_for_range(context_range)
+            .collect::<String>()
+            .into(),
+        editable_range_in_excerpt,
+        cursor_offset_in_excerpt,
+        events,
+        related_files,
+    };
+    (editable_offset_range, prompt_input)
+}

crates/edit_prediction_cli/Cargo.toml 🔗

@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
 workspace = true
 
 [[bin]]
-name = "ep_cli"
+name = "ep"
 path = "src/main.rs"
 
 [dependencies]
@@ -20,10 +20,9 @@ chrono.workspace = true
 clap.workspace = true
 client.workspace = true
 cloud_llm_client.workspace= true
-cloud_zeta2_prompt.workspace = true
 collections.workspace = true
 debug_adapter_extension.workspace = true
-edit_prediction_context.workspace = true
+dirs.workspace = true
 extension.workspace = true
 fs.workspace = true
 futures.workspace = true
@@ -51,12 +50,21 @@ smol.workspace = true
 sqlez.workspace = true
 sqlez_macros.workspace = true
 terminal_view.workspace = true
-toml.workspace = true
 util.workspace = true
 watch.workspace = true
 edit_prediction = { workspace = true, features = ["eval-support"] }
+wasmtime.workspace = true
+zeta_prompt.workspace = true
 zlog.workspace = true
 
+# Wasmtime is included as a dependency in order to enable the same
+# features that are enabled in Zed.
+#
+# If we don't enable these features we get crashes when creating
+# a Tree-sitter WasmStore.
+[package.metadata.cargo-machete]
+ignored = ["wasmtime"]
+
 [dev-dependencies]
 indoc.workspace = true
 gpui = { workspace = true, features = ["test-support"] }

crates/edit_prediction_cli/src/training/llm_client.rs → crates/edit_prediction_cli/src/anthropic_client.rs 🔗

@@ -5,11 +5,13 @@ use anthropic::{
 use anyhow::Result;
 use http_client::HttpClient;
 use indoc::indoc;
+use reqwest_client::ReqwestClient;
 use sqlez::bindable::Bind;
 use sqlez::bindable::StaticColumnCount;
 use sqlez_macros::sql;
 use std::hash::Hash;
 use std::hash::Hasher;
+use std::path::Path;
 use std::sync::Arc;
 
 pub struct PlainLlmClient {
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
 }
 
 impl PlainLlmClient {
-    fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+    fn new() -> Result<Self> {
+        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
         let api_key = std::env::var("ANTHROPIC_API_KEY")
             .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
         Ok(Self {
@@ -29,12 +32,12 @@ impl PlainLlmClient {
 
     async fn generate(
         &self,
-        model: String,
+        model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
     ) -> Result<AnthropicResponse> {
         let request = AnthropicRequest {
-            model,
+            model: model.to_string(),
             max_tokens,
             messages,
             tools: Vec::new(),
@@ -105,11 +108,12 @@ struct SerializableMessage {
 }
 
 impl BatchingLlmClient {
-    fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+    fn new(cache_path: &Path) -> Result<Self> {
+        let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
         let api_key = std::env::var("ANTHROPIC_API_KEY")
             .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
 
-        let connection = sqlez::connection::Connection::open_file(&cache_path);
+        let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
         let mut statement = sqlez::statement::Statement::prepare(
             &connection,
             indoc! {"
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
 
     async fn generate(
         &self,
-        model: String,
+        model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
     ) -> Result<Option<AnthropicResponse>> {
-        let response = self.lookup(&model, max_tokens, &messages)?;
+        let response = self.lookup(model, max_tokens, &messages)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(&model, max_tokens, &messages)?;
+        self.mark_for_batch(model, max_tokens, &messages)?;
 
         Ok(None)
     }
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
                         }
                     }
                 }
-                log::info!("Uploaded {} successful requests", success_count);
+                log::info!("Downloaded {} successful requests", success_count);
             }
         }
 
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
         .join("\n")
 }
 
-pub enum LlmClient {
+pub enum AnthropicClient {
     // No batching
     Plain(PlainLlmClient),
     Batch(BatchingLlmClient),
     Dummy,
 }
 
-impl LlmClient {
-    pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
-        Ok(Self::Plain(PlainLlmClient::new(http_client)?))
+impl AnthropicClient {
+    pub fn plain() -> Result<Self> {
+        Ok(Self::Plain(PlainLlmClient::new()?))
     }
 
-    pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
-        Ok(Self::Batch(BatchingLlmClient::new(
-            cache_path,
-            http_client,
-        )?))
+    pub fn batch(cache_path: &Path) -> Result<Self> {
+        Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
     }
 
     #[allow(dead_code)]
@@ -389,29 +390,29 @@ impl LlmClient {
 
     pub async fn generate(
         &self,
-        model: String,
+        model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
     ) -> Result<Option<AnthropicResponse>> {
         match self {
-            LlmClient::Plain(plain_llm_client) => plain_llm_client
+            AnthropicClient::Plain(plain_llm_client) => plain_llm_client
                 .generate(model, max_tokens, messages)
                 .await
                 .map(Some),
-            LlmClient::Batch(batching_llm_client) => {
+            AnthropicClient::Batch(batching_llm_client) => {
                 batching_llm_client
                     .generate(model, max_tokens, messages)
                     .await
             }
-            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
         }
     }
 
     pub async fn sync_batches(&self) -> Result<()> {
         match self {
-            LlmClient::Plain(_) => Ok(()),
-            LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
-            LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+            AnthropicClient::Plain(_) => Ok(()),
+            AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
+            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
         }
     }
 }

crates/edit_prediction_cli/src/evaluate.rs 🔗

@@ -1,641 +0,0 @@
-use crate::metrics::{self, Scores};
-use std::{
-    collections::HashMap,
-    io::{IsTerminal, Write},
-    sync::Arc,
-};
-
-use anyhow::Result;
-use edit_prediction::{EditPredictionStore, udiff::DiffLine};
-use gpui::{AsyncApp, Entity};
-use project::Project;
-use util::ResultExt as _;
-
-use crate::{
-    EvaluateArguments, PredictionOptions,
-    example::{Example, NamedExample},
-    headless::ZetaCliAppState,
-    paths::print_run_data_dir,
-    predict::{PredictionDetails, perform_predict, setup_store},
-};
-
-#[derive(Debug)]
-pub(crate) struct ExecutionData {
-    execution_id: String,
-    diff: String,
-    reasoning: String,
-}
-
-pub async fn run_evaluate(
-    args: EvaluateArguments,
-    app_state: &Arc<ZetaCliAppState>,
-    cx: &mut AsyncApp,
-) {
-    if args.example_paths.is_empty() {
-        eprintln!("No examples provided");
-        return;
-    }
-
-    let all_tasks = args.example_paths.into_iter().map(|path| {
-        let options = args.options.clone();
-        let app_state = app_state.clone();
-        let example = NamedExample::load(&path).expect("Failed to load example");
-
-        cx.spawn(async move |cx| {
-            let project = example.setup_project(&app_state, cx).await.unwrap();
-
-            let providers = (0..args.repetitions)
-                .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
-                .collect::<Vec<_>>();
-
-            let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
-
-            let tasks = providers
-                .into_iter()
-                .enumerate()
-                .map(move |(repetition_ix, store)| {
-                    let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
-                    let example = example.clone();
-                    let project = project.clone();
-                    let options = options.clone();
-
-                    cx.spawn(async move |cx| {
-                        let name = example.name.clone();
-                        run_evaluate_one(
-                            example,
-                            repetition_ix,
-                            project,
-                            store,
-                            options,
-                            !args.skip_prediction,
-                            cx,
-                        )
-                        .await
-                        .map_err(|err| (err, name, repetition_ix))
-                    })
-                });
-            futures::future::join_all(tasks).await
-        })
-    });
-    let all_results = futures::future::join_all(all_tasks).await;
-
-    write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
-    if let Some(mut output_file) =
-        std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
-    {
-        write_aggregated_scores(&mut output_file, &all_results).log_err();
-    };
-
-    if args.repetitions > 1 {
-        if let Err(e) = write_bucketed_analysis(&all_results) {
-            eprintln!("Failed to write bucketed analysis: {:?}", e);
-        }
-    }
-
-    print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
-}
-
-fn write_aggregated_scores(
-    w: &mut impl std::io::Write,
-    all_results: &Vec<
-        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
-    >,
-) -> Result<()> {
-    let mut successful = Vec::new();
-    let mut failed_count = 0;
-
-    for result in all_results.iter().flatten() {
-        match result {
-            Ok((eval_result, _execution_data)) => successful.push(eval_result),
-            Err((err, name, repetition_ix)) => {
-                if failed_count == 0 {
-                    writeln!(w, "## Errors\n")?;
-                }
-
-                failed_count += 1;
-                writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
-            }
-        }
-    }
-
-    if successful.len() > 1 {
-        let edit_scores = successful
-            .iter()
-            .filter_map(|r| r.edit_scores.clone())
-            .collect::<Vec<_>>();
-        let has_edit_predictions = edit_scores.len() > 0;
-        let aggregated_result = EvaluationResult {
-            context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
-            edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
-            prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
-            generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
-                / successful.len(),
-        };
-
-        writeln!(w, "\n{}", "-".repeat(80))?;
-        writeln!(w, "\n## TOTAL SCORES")?;
-        writeln!(w, "{:#}", aggregated_result)?;
-    }
-
-    if successful.len() + failed_count > 1 {
-        writeln!(
-            w,
-            "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
-            successful.len(),
-            successful.len() + failed_count,
-            (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
-        )?;
-    }
-
-    Ok(())
-}
-
-pub async fn run_evaluate_one(
-    example: NamedExample,
-    repetition_ix: Option<u16>,
-    project: Entity<Project>,
-    store: Entity<EditPredictionStore>,
-    prediction_options: PredictionOptions,
-    predict: bool,
-    cx: &mut AsyncApp,
-) -> Result<(EvaluationResult, ExecutionData)> {
-    let predict_result = perform_predict(
-        example.clone(),
-        project,
-        store,
-        repetition_ix,
-        prediction_options,
-        cx,
-    )
-    .await?;
-
-    let evaluation_result = evaluate(&example.example, &predict_result, predict);
-
-    if repetition_ix.is_none() {
-        write_eval_result(
-            &example,
-            &predict_result,
-            &evaluation_result,
-            &mut std::io::stdout(),
-            std::io::stdout().is_terminal(),
-            predict,
-        )?;
-    }
-
-    if let Some(mut results_file) =
-        std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
-    {
-        write_eval_result(
-            &example,
-            &predict_result,
-            &evaluation_result,
-            &mut results_file,
-            false,
-            predict,
-        )
-        .log_err();
-    }
-
-    let execution_data = ExecutionData {
-        execution_id: if let Some(rep_ix) = repetition_ix {
-            format!("{:03}", rep_ix)
-        } else {
-            example.name.clone()
-        },
-        diff: predict_result.diff.clone(),
-        reasoning: std::fs::read_to_string(
-            predict_result
-                .run_example_dir
-                .join("prediction_response.md"),
-        )
-        .unwrap_or_default(),
-    };
-
-    anyhow::Ok((evaluation_result, execution_data))
-}
-
-fn write_eval_result(
-    example: &NamedExample,
-    predictions: &PredictionDetails,
-    evaluation_result: &EvaluationResult,
-    out: &mut impl Write,
-    use_color: bool,
-    predict: bool,
-) -> Result<()> {
-    if predict {
-        writeln!(
-            out,
-            "## Expected edit prediction:\n\n```diff\n{}\n```\n",
-            compare_diffs(
-                &example.example.expected_patch,
-                &predictions.diff,
-                use_color
-            )
-        )?;
-        writeln!(
-            out,
-            "## Actual edit prediction:\n\n```diff\n{}\n```\n",
-            compare_diffs(
-                &predictions.diff,
-                &example.example.expected_patch,
-                use_color
-            )
-        )?;
-    }
-
-    writeln!(out, "{:#}", evaluation_result)?;
-
-    anyhow::Ok(())
-}
-
-#[derive(Debug, Default, Clone)]
-pub struct EditScores {
-    pub line_match: Scores,
-    pub chr_f: f64,
-}
-
-impl EditScores {
-    pub fn aggregate(scores: &[EditScores]) -> EditScores {
-        let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
-        let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
-
-        EditScores { line_match, chr_f }
-    }
-}
-
-#[derive(Debug, Default)]
-pub struct EvaluationResult {
-    pub edit_scores: Option<EditScores>,
-    pub context_scores: Scores,
-    pub prompt_len: usize,
-    pub generated_len: usize,
-}
-
-impl std::fmt::Display for EvaluationResult {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        if f.alternate() {
-            self.fmt_table(f)
-        } else {
-            self.fmt_markdown(f)
-        }
-    }
-}
-
-impl EvaluationResult {
-    fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            r#"
-### Context Scores
-{}
-"#,
-            self.context_scores.to_markdown(),
-        )?;
-        if let Some(scores) = &self.edit_scores {
-            write!(
-                f,
-                r#"
-                ### Edit Prediction Scores
-                {}"#,
-                scores.line_match.to_markdown()
-            )?;
-        }
-        Ok(())
-    }
-
-    fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        writeln!(f, "#### Prompt Statistics")?;
-        writeln!(f, "─────────────────────────")?;
-        writeln!(f, "Prompt_len  Generated_len")?;
-        writeln!(f, "─────────────────────────")?;
-        writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
-        writeln!(f)?;
-        writeln!(f)?;
-        writeln!(f, "#### Performance Scores")?;
-        writeln!(
-            f,
-            "──────────────────────────────────────────────────────────────────"
-        )?;
-        writeln!(
-            f,
-            "                   TP     FP     FN     Precision   Recall     F1"
-        )?;
-        writeln!(
-            f,
-            "──────────────────────────────────────────────────────────────────"
-        )?;
-        writeln!(
-            f,
-            "Context Retrieval  {:<6} {:<6} {:<6} {:>8.2}  {:>7.2}  {:>6.2}",
-            self.context_scores.true_positives,
-            self.context_scores.false_positives,
-            self.context_scores.false_negatives,
-            self.context_scores.precision() * 100.0,
-            self.context_scores.recall() * 100.0,
-            self.context_scores.f1_score() * 100.0
-        )?;
-        if let Some(edit_scores) = &self.edit_scores {
-            let line_match = &edit_scores.line_match;
-            writeln!(f, "Edit Prediction")?;
-            writeln!(
-                f,
-                "  ├─ exact lines   {:<6} {:<6} {:<6} {:>8.2}  {:>7.2}  {:>6.2}",
-                line_match.true_positives,
-                line_match.false_positives,
-                line_match.false_negatives,
-                line_match.precision() * 100.0,
-                line_match.recall() * 100.0,
-                line_match.f1_score() * 100.0
-            )?;
-            writeln!(
-                f,
-                "  └─ diff chrF     {:<6} {:<6} {:<6} {:>8} {:>8}  {:>6.2}",
-                "-", "-", "-", "-", "-", edit_scores.chr_f
-            )?;
-        }
-        Ok(())
-    }
-}
-
-fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
-    let mut eval_result = EvaluationResult {
-        prompt_len: preds.prompt_len,
-        generated_len: preds.generated_len,
-        ..Default::default()
-    };
-
-    if predict {
-        // todo: alternatives for patches
-        let expected_patch = example
-            .expected_patch
-            .lines()
-            .map(DiffLine::parse)
-            .collect::<Vec<_>>();
-        let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
-
-        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
-        let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
-
-        eval_result.edit_scores = Some(EditScores { line_match, chr_f });
-    }
-
-    eval_result
-}
-
-/// Return annotated `patch_a` so that:
-/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
-/// Additions and deletions that are present in `patch_b` will be highlighted in green.
-pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
-    let green = if use_color { "\x1b[32m✓ " } else { "" };
-    let red = if use_color { "\x1b[31m✗ " } else { "" };
-    let neutral = if use_color { "  " } else { "" };
-    let reset = if use_color { "\x1b[0m" } else { "" };
-    let lines_a = patch_a.lines().map(DiffLine::parse);
-    let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
-
-    let annotated = lines_a
-        .map(|line| match line {
-            DiffLine::Addition(_) | DiffLine::Deletion(_) => {
-                if lines_b.contains(&line) {
-                    format!("{green}{line}{reset}")
-                } else {
-                    format!("{red}{line}{reset}")
-                }
-            }
-            _ => format!("{neutral}{line}{reset}"),
-        })
-        .collect::<Vec<String>>();
-
-    annotated.join("\n")
-}
-
-fn write_bucketed_analysis(
-    all_results: &Vec<
-        Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
-    >,
-) -> Result<()> {
-    #[derive(Debug)]
-    struct EditBucket {
-        diff: String,
-        is_correct: bool,
-        execution_indices: Vec<String>,
-        reasoning_samples: Vec<String>,
-    }
-
-    let mut total_executions = 0;
-    let mut empty_predictions = Vec::new();
-    let mut errors = Vec::new();
-
-    let mut buckets: HashMap<String, EditBucket> = HashMap::new();
-
-    for result in all_results.iter().flatten() {
-        total_executions += 1;
-
-        let (evaluation_result, execution_data) = match result {
-            Ok((eval_result, execution_data)) => {
-                if execution_data.diff.is_empty() {
-                    empty_predictions.push(execution_data);
-                    continue;
-                }
-                (eval_result, execution_data)
-            }
-            Err(err) => {
-                errors.push(err);
-                continue;
-            }
-        };
-
-        buckets
-            .entry(execution_data.diff.clone())
-            .and_modify(|bucket| {
-                bucket
-                    .execution_indices
-                    .push(execution_data.execution_id.clone());
-                bucket
-                    .reasoning_samples
-                    .push(execution_data.reasoning.clone());
-            })
-            .or_insert_with(|| EditBucket {
-                diff: execution_data.diff.clone(),
-                is_correct: {
-                    evaluation_result
-                        .edit_scores
-                        .as_ref()
-                        .map_or(false, |edit_scores| {
-                            edit_scores.line_match.false_positives == 0
-                                && edit_scores.line_match.false_negatives == 0
-                                && edit_scores.line_match.true_positives > 0
-                        })
-                },
-                execution_indices: vec![execution_data.execution_id.clone()],
-                reasoning_samples: vec![execution_data.reasoning.clone()],
-            });
-    }
-
-    let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
-    sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
-        (true, false) => std::cmp::Ordering::Less,
-        (false, true) => std::cmp::Ordering::Greater,
-        _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
-    });
-
-    let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
-    let mut output = std::fs::File::create(&output_path)?;
-
-    writeln!(output, "# Bucketed Edit Analysis\n")?;
-
-    writeln!(output, "## Summary\n")?;
-    writeln!(output, "- **Total executions**: {}", total_executions)?;
-
-    let correct_count: usize = sorted_buckets
-        .iter()
-        .filter(|b| b.is_correct)
-        .map(|b| b.execution_indices.len())
-        .sum();
-
-    let incorrect_count: usize = sorted_buckets
-        .iter()
-        .filter(|b| !b.is_correct)
-        .map(|b| b.execution_indices.len())
-        .sum();
-
-    writeln!(
-        output,
-        "- **Correct predictions**: {} ({:.1}%)",
-        correct_count,
-        (correct_count as f64 / total_executions as f64) * 100.0
-    )?;
-
-    writeln!(
-        output,
-        "- **Incorrect predictions**: {} ({:.1}%)",
-        incorrect_count,
-        (incorrect_count as f64 / total_executions as f64) * 100.0
-    )?;
-
-    writeln!(
-        output,
-        "- **No Predictions**: {} ({:.1}%)",
-        empty_predictions.len(),
-        (empty_predictions.len() as f64 / total_executions as f64) * 100.0
-    )?;
-
-    let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
-    writeln!(
-        output,
-        "- **Unique incorrect edit patterns**: {}\n",
-        unique_incorrect
-    )?;
-
-    writeln!(output, "---\n")?;
-
-    for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
-        if idx == 0 {
-            writeln!(
-                output,
-                "## Correct Predictions ({} occurrences)\n",
-                bucket.execution_indices.len()
-            )?;
-        }
-
-        writeln!(output, "**Predicted Edit:**\n")?;
-        writeln!(output, "```diff")?;
-        writeln!(output, "{}", bucket.diff)?;
-        writeln!(output, "```\n")?;
-
-        writeln!(
-            output,
-            "**Executions:** {}\n",
-            bucket.execution_indices.join(", ")
-        )?;
-        writeln!(output, "---\n")?;
-    }
-
-    for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
-        writeln!(
-            output,
-            "## Incorrect Prediction #{} ({} occurrences)\n",
-            idx + 1,
-            bucket.execution_indices.len()
-        )?;
-
-        writeln!(output, "**Predicted Edit:**\n")?;
-        writeln!(output, "```diff")?;
-        writeln!(output, "{}", bucket.diff)?;
-        writeln!(output, "```\n")?;
-
-        writeln!(
-            output,
-            "**Executions:** {}\n",
-            bucket.execution_indices.join(", ")
-        )?;
-
-        for (exec_id, reasoning) in bucket
-            .execution_indices
-            .iter()
-            .zip(bucket.reasoning_samples.iter())
-        {
-            writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
-        }
-
-        writeln!(output, "\n---\n")?;
-    }
-
-    if !empty_predictions.is_empty() {
-        writeln!(
-            output,
-            "## No Predictions ({} occurrences)\n",
-            empty_predictions.len()
-        )?;
-
-        for execution_data in &empty_predictions {
-            writeln!(
-                output,
-                "{}",
-                fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
-            )?;
-        }
-        writeln!(output, "\n---\n")?;
-    }
-
-    if !errors.is_empty() {
-        writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
-
-        for (err, name, repetition_ix) in &errors {
-            writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
-        }
-        writeln!(output, "\n---\n")?;
-    }
-
-    fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
-        let exec_content = format!(
-            "\n### Execution {} `{}/{}/prediction_response.md`{}",
-            exec_id,
-            crate::paths::RUN_DIR.display(),
-            exec_id,
-            indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
-        );
-        indent_text(&exec_content, 2)
-    }
-
-    fn indent_text(text: &str, spaces: usize) -> String {
-        let indent = " ".repeat(spaces);
-        text.lines()
-            .collect::<Vec<_>>()
-            .join(&format!("\n{}", indent))
-    }
-
-    Ok(())
-}
-
-fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
-    let err = format!("{err:?}")
-        .replace("<edits", "```xml\n<edits")
-        .replace("</edits>", "</edits>\n```");
-    format!(
-        "### ERROR {name}{}\n\n{err}\n",
-        repetition_ix
-            .map(|ix| format!(" [RUN {ix:03}]"))
-            .unwrap_or_default()
-    )
-}

crates/edit_prediction_cli/src/example.rs 🔗

@@ -1,59 +1,103 @@
+use crate::{
+    PredictionProvider, PromptFormat,
+    metrics::ClassificationMetrics,
+    paths::{REPOS_DIR, WORKTREES_DIR},
+};
+use anyhow::{Context as _, Result};
+use edit_prediction::udiff::OpenedBuffers;
+use gpui::Entity;
+use http_client::Url;
+use language::{Anchor, Buffer};
+use project::Project;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
 use std::{
     borrow::Cow,
-    cell::RefCell,
-    fmt::{self, Display},
-    fs,
-    hash::Hash,
-    hash::Hasher,
-    io::Write,
+    io::{Read, Write},
     mem,
     path::{Path, PathBuf},
-    sync::{Arc, OnceLock},
 };
+use zeta_prompt::RelatedFile;
 
-use crate::headless::ZetaCliAppState;
-use anyhow::{Context as _, Result, anyhow};
-use clap::ValueEnum;
-use cloud_zeta2_prompt::CURSOR_MARKER;
-use collections::HashMap;
-use edit_prediction::udiff::OpenedBuffers;
-use futures::{
-    AsyncWriteExt as _,
-    lock::{Mutex, OwnedMutexGuard},
-};
-use futures::{FutureExt as _, future::Shared};
-use gpui::{AsyncApp, Entity, Task, http_client::Url};
-use language::{Anchor, Buffer};
-use project::{Project, ProjectPath};
-use pulldown_cmark::CowStr;
-use serde::{Deserialize, Serialize};
-use util::{paths::PathStyle, rel_path::RelPath};
-
-use crate::paths::{REPOS_DIR, WORKTREES_DIR};
-
-const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
-const EDIT_HISTORY_HEADING: &str = "Edit History";
-const CURSOR_POSITION_HEADING: &str = "Cursor Position";
-const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
-const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
-const REPOSITORY_URL_FIELD: &str = "repository_url";
-const REVISION_FIELD: &str = "revision";
-
-#[derive(Debug, Clone)]
-pub struct NamedExample {
-    pub name: String,
-    pub example: Example,
-}
-
-#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct Example {
+    #[serde(default)]
+    pub name: String,
     pub repository_url: String,
     pub revision: String,
     pub uncommitted_diff: String,
-    pub cursor_path: PathBuf,
+    pub cursor_path: Arc<Path>,
     pub cursor_position: String,
     pub edit_history: String,
     pub expected_patch: String,
+
+    /// The full content of the file where an edit is being predicted, and the
+    /// actual cursor offset.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub buffer: Option<ExampleBuffer>,
+
+    /// The context retrieved for the prediction. This requires the worktree to
+    /// be loaded and the language server to be started.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub context: Option<ExampleContext>,
+
+    /// The input and expected output from the edit prediction model.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt: Option<ExamplePrompt>,
+
+    /// The actual predictions from the model.
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub predictions: Vec<ExamplePrediction>,
+
+    /// The scores, for how well the actual predictions match the expected
+    /// predictions.
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub score: Vec<ExampleScore>,
+
+    /// The application state used to process this example.
+    #[serde(skip)]
+    pub state: Option<ExampleState>,
+}
+
+#[derive(Clone, Debug)]
+pub struct ExampleState {
+    pub project: Entity<Project>,
+    pub buffer: Entity<Buffer>,
+    pub cursor_position: Anchor,
+    pub _open_buffers: OpenedBuffers,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleContext {
+    pub files: Arc<[RelatedFile]>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleBuffer {
+    pub content: String,
+    pub cursor_row: u32,
+    pub cursor_column: u32,
+    pub cursor_offset: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExamplePrompt {
+    pub input: String,
+    pub expected_output: String,
+    pub format: PromptFormat,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExamplePrediction {
+    pub actual_patch: String,
+    pub actual_output: String,
+    pub provider: PredictionProvider,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleScore {
+    pub delta_chr_f: f32,
+    pub line_match: ClassificationMetrics,
 }
 
 impl Example {
@@ -90,485 +134,244 @@ impl Example {
         }
     }
 
-    pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
-        let (repo_owner, repo_name) = self.repo_name()?;
-
-        let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
-        let repo_lock = lock_repo(&repo_dir).await;
+    pub fn worktree_path(&self) -> PathBuf {
+        WORKTREES_DIR
+            .join(&self.name)
+            .join(self.repo_name().unwrap().1.as_ref())
+    }
 
-        if !repo_dir.is_dir() {
-            fs::create_dir_all(&repo_dir)?;
-            run_git(&repo_dir, &["init"]).await?;
-            run_git(
-                &repo_dir,
-                &["remote", "add", "origin", &self.repository_url],
-            )
-            .await?;
-        }
+    pub fn repo_path(&self) -> PathBuf {
+        let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
+        REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
+    }
+}
 
-        // Resolve the example to a revision, fetching it if needed.
-        let revision = run_git(
-            &repo_dir,
-            &["rev-parse", &format!("{}^{{commit}}", self.revision)],
-        )
-        .await;
-        let revision = if let Ok(revision) = revision {
-            revision
+pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
+    let mut examples = Vec::new();
+
+    let stdin_path: PathBuf = PathBuf::from("-");
+
+    let inputs = if inputs.is_empty() {
+        &[stdin_path]
+    } else {
+        inputs
+    };
+
+    for path in inputs {
+        let is_stdin = path.as_path() == Path::new("-");
+        let content = if is_stdin {
+            let mut buffer = String::new();
+            std::io::stdin()
+                .read_to_string(&mut buffer)
+                .expect("Failed to read from stdin");
+            buffer
         } else {
-            if run_git(
-                &repo_dir,
-                &["fetch", "--depth", "1", "origin", &self.revision],
-            )
-            .await
-            .is_err()
-            {
-                run_git(&repo_dir, &["fetch", "origin"]).await?;
-            }
-            let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
-            if revision != self.revision {
-                run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
-            }
-            revision
+            std::fs::read_to_string(path)
+                .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
         };
-
-        // Create the worktree for this example if needed.
-        let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
-        if worktree_path.is_dir() {
-            run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
-            run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
-            run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+        let filename = path.file_stem().unwrap().to_string_lossy().to_string();
+        let ext = if !is_stdin {
+            path.extension()
+                .map(|ext| ext.to_string_lossy().to_string())
+                .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
         } else {
-            let worktree_path_string = worktree_path.to_string_lossy();
-            run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
-            run_git(
-                &repo_dir,
-                &["worktree", "add", "-f", &worktree_path_string, &file_name],
-            )
-            .await?;
-        }
-        drop(repo_lock);
-
-        // Apply the uncommitted diff for this example.
-        if !self.uncommitted_diff.is_empty() {
-            let mut apply_process = smol::process::Command::new("git")
-                .current_dir(&worktree_path)
-                .args(&["apply", "-"])
-                .stdin(std::process::Stdio::piped())
-                .spawn()?;
-
-            let mut stdin = apply_process.stdin.take().unwrap();
-            stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
-            stdin.close().await?;
-            drop(stdin);
-
-            let apply_result = apply_process.output().await?;
-            if !apply_result.status.success() {
-                anyhow::bail!(
-                    "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
-                    apply_result.status,
-                    String::from_utf8_lossy(&apply_result.stderr),
-                    String::from_utf8_lossy(&apply_result.stdout),
-                );
+            "jsonl".to_string()
+        };
+
+        match ext.as_ref() {
+            "json" => {
+                let mut example =
+                    serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
+                        panic!("Failed to parse example file: {}\n{error}", path.display())
+                    });
+                if example.name.is_empty() {
+                    example.name = filename;
+                }
+                examples.push(example);
+            }
+            "jsonl" => examples.extend(
+                content
+                    .lines()
+                    .enumerate()
+                    .map(|(line_ix, line)| {
+                        let mut example =
+                            serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
+                                panic!(
+                                    "Failed to parse example on {}:{}",
+                                    path.display(),
+                                    line_ix + 1
+                                )
+                            });
+                        if example.name.is_empty() {
+                            example.name = format!("{filename}-{line_ix}")
+                        }
+                        example
+                    })
+                    .collect::<Vec<Example>>(),
+            ),
+            "md" => {
+                examples.push(parse_markdown_example(filename, &content).unwrap());
+            }
+            ext => {
+                panic!("{} has invalid example extension `{ext}`", path.display())
             }
         }
-
-        Ok(worktree_path)
-    }
-
-    pub fn unique_name(&self) -> String {
-        let mut hasher = std::hash::DefaultHasher::new();
-        self.hash(&mut hasher);
-        let disambiguator = hasher.finish();
-        let hash = format!("{:04x}", disambiguator);
-        format!("{}_{}", &self.revision[..8], &hash[..4])
     }
+    examples
 }
 
-pub type ActualExcerpt = Excerpt;
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct Excerpt {
-    pub path: PathBuf,
-    pub text: String,
-}
-
-#[derive(ValueEnum, Debug, Clone)]
-pub enum ExampleFormat {
-    Json,
-    Toml,
-    Md,
+pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
+    let mut content = String::new();
+    for example in examples {
+        let line = serde_json::to_string(example).unwrap();
+        content.push_str(&line);
+        content.push('\n');
+    }
+    if let Some(output_path) = output_path {
+        std::fs::write(output_path, content).expect("Failed to write examples");
+    } else {
+        std::io::stdout().write_all(&content.as_bytes()).unwrap();
+    }
 }
 
-impl NamedExample {
-    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
-        let path = path.as_ref();
-        let content = std::fs::read_to_string(path)?;
-        let ext = path.extension();
-
-        match ext.and_then(|s| s.to_str()) {
-            Some("json") => Ok(Self {
-                name: path.file_stem().unwrap_or_default().display().to_string(),
-                example: serde_json::from_str(&content)?,
-            }),
-            Some("toml") => Ok(Self {
-                name: path.file_stem().unwrap_or_default().display().to_string(),
-                example: toml::from_str(&content)?,
-            }),
-            Some("md") => Self::parse_md(&content),
-            Some(_) => {
-                anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
-            }
-            None => {
-                anyhow::bail!(
-                    "Failed to determine example type since the file does not have an extension."
-                );
-            }
-        }
+fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
+    use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
+
+    const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
+    const EDIT_HISTORY_HEADING: &str = "Edit History";
+    const CURSOR_POSITION_HEADING: &str = "Cursor Position";
+    const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
+    const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
+    const REPOSITORY_URL_FIELD: &str = "repository_url";
+    const REVISION_FIELD: &str = "revision";
+
+    let parser = Parser::new(input);
+
+    let mut example = Example {
+        name: id,
+        repository_url: String::new(),
+        revision: String::new(),
+        uncommitted_diff: String::new(),
+        cursor_path: PathBuf::new().into(),
+        cursor_position: String::new(),
+        edit_history: String::new(),
+        expected_patch: String::new(),
+        buffer: None,
+        context: None,
+        prompt: None,
+        predictions: Vec::new(),
+        score: Vec::new(),
+        state: None,
+    };
+
+    let mut name = String::new();
+    let mut text = String::new();
+    let mut block_info: CowStr = "".into();
+
+    #[derive(PartialEq)]
+    enum Section {
+        UncommittedDiff,
+        EditHistory,
+        CursorPosition,
+        ExpectedExcerpts,
+        ExpectedPatch,
+        Other,
     }
 
-    pub fn parse_md(input: &str) -> Result<Self> {
-        use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
-
-        let parser = Parser::new(input);
-
-        let mut named = NamedExample {
-            name: String::new(),
-            example: Example {
-                repository_url: String::new(),
-                revision: String::new(),
-                uncommitted_diff: String::new(),
-                cursor_path: PathBuf::new(),
-                cursor_position: String::new(),
-                edit_history: String::new(),
-                expected_patch: String::new(),
-            },
-        };
+    let mut current_section = Section::Other;
 
-        let mut text = String::new();
-        let mut block_info: CowStr = "".into();
-
-        #[derive(PartialEq)]
-        enum Section {
-            UncommittedDiff,
-            EditHistory,
-            CursorPosition,
-            ExpectedExcerpts,
-            ExpectedPatch,
-            Other,
-        }
+    for event in parser {
+        match event {
+            Event::Text(line) => {
+                text.push_str(&line);
 
-        let mut current_section = Section::Other;
-
-        for event in parser {
-            match event {
-                Event::Text(line) => {
-                    text.push_str(&line);
-
-                    if !named.name.is_empty()
-                        && current_section == Section::Other
-                        // in h1 section
-                        && let Some((field, value)) = line.split_once('=')
-                    {
-                        match field.trim() {
-                            REPOSITORY_URL_FIELD => {
-                                named.example.repository_url = value.trim().to_string();
-                            }
-                            REVISION_FIELD => {
-                                named.example.revision = value.trim().to_string();
-                            }
-                            _ => {}
-                        }
-                    }
-                }
-                Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
-                    if !named.name.is_empty() {
-                        anyhow::bail!(
-                            "Found multiple H1 headings. There should only be one with the name of the example."
-                        );
-                    }
-                    named.name = mem::take(&mut text);
-                }
-                Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
-                    let title = mem::take(&mut text);
-                    current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
-                        Section::UncommittedDiff
-                    } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
-                        Section::EditHistory
-                    } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
-                        Section::CursorPosition
-                    } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
-                        Section::ExpectedPatch
-                    } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
-                        Section::ExpectedExcerpts
-                    } else {
-                        Section::Other
-                    };
-                }
-                Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
-                    mem::take(&mut text);
-                }
-                Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
-                    mem::take(&mut text);
-                }
-                Event::End(TagEnd::Heading(level)) => {
-                    anyhow::bail!("Unexpected heading level: {level}");
-                }
-                Event::Start(Tag::CodeBlock(kind)) => {
-                    match kind {
-                        CodeBlockKind::Fenced(info) => {
-                            block_info = info;
-                        }
-                        CodeBlockKind::Indented => {
-                            anyhow::bail!("Unexpected indented codeblock");
-                        }
-                    };
-                }
-                Event::Start(_) => {
-                    text.clear();
-                    block_info = "".into();
-                }
-                Event::End(TagEnd::CodeBlock) => {
-                    let block_info = block_info.trim();
-                    match current_section {
-                        Section::UncommittedDiff => {
-                            named.example.uncommitted_diff = mem::take(&mut text);
-                        }
-                        Section::EditHistory => {
-                            named.example.edit_history.push_str(&mem::take(&mut text));
-                        }
-                        Section::CursorPosition => {
-                            named.example.cursor_path = block_info.into();
-                            named.example.cursor_position = mem::take(&mut text);
-                        }
-                        Section::ExpectedExcerpts => {
-                            mem::take(&mut text);
+                if let Some((field, value)) = line.split_once('=') {
+                    match field.trim() {
+                        REPOSITORY_URL_FIELD => {
+                            example.repository_url = value.trim().to_string();
                         }
-                        Section::ExpectedPatch => {
-                            named.example.expected_patch = mem::take(&mut text);
+                        REVISION_FIELD => {
+                            example.revision = value.trim().to_string();
                         }
-                        Section::Other => {}
+                        _ => {}
                     }
                 }
-                _ => {}
             }
-        }
-
-        if named.example.cursor_path.as_path() == Path::new("")
-            || named.example.cursor_position.is_empty()
-        {
-            anyhow::bail!("Missing cursor position codeblock");
-        }
-
-        Ok(named)
-    }
-
-    pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
-        match format {
-            ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
-            ExampleFormat::Toml => {
-                Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
+            Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
+                if !name.is_empty() {
+                    anyhow::bail!(
+                        "Found multiple H1 headings. There should only be one with the name of the example."
+                    );
+                }
+                name = mem::take(&mut text);
             }
-            ExampleFormat::Md => Ok(write!(out, "{}", self)?),
-        }
-    }
-
-    pub async fn setup_project(
-        &self,
-        app_state: &Arc<ZetaCliAppState>,
-        cx: &mut AsyncApp,
-    ) -> Result<Entity<Project>> {
-        let worktree_path = self.setup_worktree().await?;
-
-        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
-
-        AUTHENTICATED
-            .get_or_init(|| {
-                let client = app_state.client.clone();
-                cx.spawn(async move |cx| {
-                    client
-                        .sign_in_with_optional_connect(true, cx)
-                        .await
-                        .unwrap();
-                })
-                .shared()
-            })
-            .clone()
-            .await;
-
-        let project = cx.update(|cx| {
-            Project::local(
-                app_state.client.clone(),
-                app_state.node_runtime.clone(),
-                app_state.user_store.clone(),
-                app_state.languages.clone(),
-                app_state.fs.clone(),
-                None,
-                cx,
-            )
-        })?;
-
-        let worktree = project
-            .update(cx, |project, cx| {
-                project.create_worktree(&worktree_path, true, cx)
-            })?
-            .await?;
-        worktree
-            .read_with(cx, |worktree, _cx| {
-                worktree.as_local().unwrap().scan_complete()
-            })?
-            .await;
-
-        anyhow::Ok(project)
-    }
-
-    pub async fn setup_worktree(&self) -> Result<PathBuf> {
-        self.example.setup_worktree(self.file_name()).await
-    }
-
-    pub fn file_name(&self) -> String {
-        self.name
-            .chars()
-            .map(|c| {
-                if c.is_whitespace() {
-                    '-'
+            Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
+                let title = mem::take(&mut text);
+                current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
+                    Section::UncommittedDiff
+                } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
+                    Section::EditHistory
+                } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
+                    Section::CursorPosition
+                } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
+                    Section::ExpectedPatch
+                } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
+                    Section::ExpectedExcerpts
                 } else {
-                    c.to_ascii_lowercase()
+                    Section::Other
+                };
+            }
+            Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
+                mem::take(&mut text);
+            }
+            Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
+                mem::take(&mut text);
+            }
+            Event::End(TagEnd::Heading(level)) => {
+                anyhow::bail!("Unexpected heading level: {level}");
+            }
+            Event::Start(Tag::CodeBlock(kind)) => {
+                match kind {
+                    CodeBlockKind::Fenced(info) => {
+                        block_info = info;
+                    }
+                    CodeBlockKind::Indented => {
+                        anyhow::bail!("Unexpected indented codeblock");
+                    }
+                };
+            }
+            Event::Start(_) => {
+                text.clear();
+                block_info = "".into();
+            }
+            Event::End(TagEnd::CodeBlock) => {
+                let block_info = block_info.trim();
+                match current_section {
+                    Section::UncommittedDiff => {
+                        example.uncommitted_diff = mem::take(&mut text);
+                    }
+                    Section::EditHistory => {
+                        example.edit_history.push_str(&mem::take(&mut text));
+                    }
+                    Section::CursorPosition => {
+                        example.cursor_path = Path::new(block_info).into();
+                        example.cursor_position = mem::take(&mut text);
+                    }
+                    Section::ExpectedExcerpts => {
+                        mem::take(&mut text);
+                    }
+                    Section::ExpectedPatch => {
+                        example.expected_patch = mem::take(&mut text);
+                    }
+                    Section::Other => {}
                 }
-            })
-            .collect()
-    }
-
-    pub async fn cursor_position(
-        &self,
-        project: &Entity<Project>,
-        cx: &mut AsyncApp,
-    ) -> Result<(Entity<Buffer>, Anchor)> {
-        let worktree = project.read_with(cx, |project, cx| {
-            project.visible_worktrees(cx).next().unwrap()
-        })?;
-        let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
-        let cursor_buffer = project
-            .update(cx, |project, cx| {
-                project.open_buffer(
-                    ProjectPath {
-                        worktree_id: worktree.read(cx).id(),
-                        path: cursor_path,
-                    },
-                    cx,
-                )
-            })?
-            .await?;
-        let cursor_offset_within_excerpt = self
-            .example
-            .cursor_position
-            .find(CURSOR_MARKER)
-            .ok_or_else(|| anyhow!("missing cursor marker"))?;
-        let mut cursor_excerpt = self.example.cursor_position.clone();
-        cursor_excerpt.replace_range(
-            cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
-            "",
-        );
-        let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
-            let text = buffer.text();
-
-            let mut matches = text.match_indices(&cursor_excerpt);
-            let Some((excerpt_offset, _)) = matches.next() else {
-                anyhow::bail!(
-                    "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
-                );
-            };
-            assert!(matches.next().is_none());
-
-            Ok(excerpt_offset)
-        })??;
-
-        let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
-        let cursor_anchor =
-            cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
-        Ok((cursor_buffer, cursor_anchor))
-    }
-
-    #[must_use]
-    pub async fn apply_edit_history(
-        &self,
-        project: &Entity<Project>,
-        cx: &mut AsyncApp,
-    ) -> Result<OpenedBuffers<'_>> {
-        edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
-    }
-}
-
-async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
-    let output = smol::process::Command::new("git")
-        .current_dir(repo_path)
-        .args(args)
-        .output()
-        .await?;
-
-    anyhow::ensure!(
-        output.status.success(),
-        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
-        args.join(" "),
-        repo_path.display(),
-        output.status,
-        String::from_utf8_lossy(&output.stderr),
-        String::from_utf8_lossy(&output.stdout),
-    );
-    Ok(String::from_utf8(output.stdout)?.trim().to_string())
-}
-
-impl Display for NamedExample {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(f, "# {}\n\n", self.name)?;
-        write!(
-            f,
-            "{REPOSITORY_URL_FIELD} = {}\n",
-            self.example.repository_url
-        )?;
-        write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
-
-        write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
-        write!(f, "`````diff\n")?;
-        write!(f, "{}", self.example.uncommitted_diff)?;
-        write!(f, "`````\n")?;
-
-        if !self.example.edit_history.is_empty() {
-            write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
-        }
-
-        write!(
-            f,
-            "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
-            self.example.cursor_path.display(),
-            self.example.cursor_position
-        )?;
-        write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
-
-        if !self.example.expected_patch.is_empty() {
-            write!(
-                f,
-                "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
-                self.example.expected_patch
-            )?;
+            }
+            _ => {}
         }
-
-        Ok(())
     }
-}
-
-thread_local! {
-    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
-}
+    if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
+        anyhow::bail!("Missing cursor position codeblock");
+    }
 
-#[must_use]
-pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
-    REPO_LOCKS
-        .with(|cell| {
-            cell.borrow_mut()
-                .entry(path.as_ref().to_path_buf())
-                .or_default()
-                .clone()
-        })
-        .lock_owned()
-        .await
+    Ok(example)
 }

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -0,0 +1,280 @@
+use crate::{
+    PromptFormat,
+    example::{Example, ExamplePrompt},
+    headless::EpAppState,
+    retrieve_context::run_context_retrieval,
+};
+use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
+use gpui::AsyncApp;
+use std::sync::Arc;
+use zeta_prompt::format_zeta_prompt;
+
+pub async fn run_format_prompt(
+    example: &mut Example,
+    prompt_format: PromptFormat,
+    app_state: Arc<EpAppState>,
+    mut cx: AsyncApp,
+) {
+    run_context_retrieval(example, app_state, cx.clone()).await;
+
+    let prompt = match prompt_format {
+        PromptFormat::Teacher => TeacherPrompt::format(example),
+        PromptFormat::Zeta2 => {
+            let ep_store = cx
+                .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+                .unwrap();
+
+            let state = example.state.as_ref().unwrap();
+            let snapshot = state
+                .buffer
+                .read_with(&cx, |buffer, _| buffer.snapshot())
+                .unwrap();
+            let project = state.project.clone();
+            let (_, input) = ep_store
+                .update(&mut cx, |ep_store, _cx| {
+                    zeta2_prompt_input(
+                        &snapshot,
+                        example.context.as_ref().unwrap().files.clone(),
+                        ep_store.edit_history_for_project(&project),
+                        example.cursor_path.clone(),
+                        example.buffer.as_ref().unwrap().cursor_offset,
+                    )
+                })
+                .unwrap();
+            format_zeta_prompt(&input)
+        }
+    };
+
+    example.prompt = Some(ExamplePrompt {
+        input: prompt,
+        expected_output: example.expected_patch.clone(), // TODO
+        format: prompt_format,
+    });
+}
+
+pub trait PromptFormatter {
+    fn format(example: &Example) -> String;
+}
+
+pub trait PromptParser {
+    /// Return unified diff patch of prediction given raw LLM response
+    fn parse(example: &Example, response: &str) -> String;
+}
+
+pub struct TeacherPrompt;
+
+impl PromptFormatter for TeacherPrompt {
+    fn format(example: &Example) -> String {
+        let edit_history = Self::format_edit_history(&example.edit_history);
+        let context = Self::format_context(example);
+        let editable_region = Self::format_editable_region(example);
+
+        let prompt = Self::PROMPT
+            .replace("{{context}}", &context)
+            .replace("{{edit_history}}", &edit_history)
+            .replace("{{editable_region}}", &editable_region);
+
+        prompt
+    }
+}
+
+impl TeacherPrompt {
+    const PROMPT: &str = include_str!("teacher.prompt.md");
+    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
+    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
+
+    /// Truncate edit history to this number of last lines
+    const MAX_HISTORY_LINES: usize = 128;
+
+    fn format_edit_history(edit_history: &str) -> String {
+        // Strip comments ("garbage lines") from edit history
+        let lines = edit_history
+            .lines()
+            .filter(|&s| Self::is_udiff_content_line(s))
+            .collect::<Vec<_>>();
+
+        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
+            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
+        } else {
+            &lines
+        };
+
+        if history_lines.is_empty() {
+            return "(No edit history)".to_string();
+        }
+
+        history_lines.join("\n")
+    }
+
+    fn format_context(example: &Example) -> String {
+        if example.context.is_none() {
+            panic!("Missing context retriever step");
+        }
+
+        let mut prompt = String::new();
+        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
+
+        prompt
+    }
+
+    fn format_editable_region(example: &Example) -> String {
+        let mut result = String::new();
+
+        let path_str = example.cursor_path.to_string_lossy();
+        result.push_str(&format!("`````path=\"{path_str}\"\n"));
+        result.push_str(Self::EDITABLE_REGION_START);
+
+        // TODO: control number of lines around cursor
+        result.push_str(&example.cursor_position);
+        if !example.cursor_position.ends_with('\n') {
+            result.push('\n');
+        }
+
+        result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
+        result.push_str("`````");
+
+        result
+    }
+
+    fn extract_editable_region(text: &str) -> String {
+        let start = text
+            .find(Self::EDITABLE_REGION_START)
+            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
+        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
+
+        let region = &text[start..end];
+
+        region.replace("<|user_cursor|>", "")
+    }
+
+    fn is_udiff_content_line(s: &str) -> bool {
+        s.starts_with("-")
+            || s.starts_with("+")
+            || s.starts_with(" ")
+            || s.starts_with("---")
+            || s.starts_with("+++")
+            || s.starts_with("@@")
+    }
+}
+
+impl PromptParser for TeacherPrompt {
+    fn parse(example: &Example, response: &str) -> String {
+        // Ideally, we should always be able to find cursor position in the retrieved context.
+        // In reality, sometimes we don't find it for these reasons:
+        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
+        //    (can be fixed by getting cursor coordinates at the load_example stage)
+        // 2. Context retriever just didn't include cursor line.
+        //
+        // In that case, fallback to using `cursor_position` as excerpt.
+        let cursor_file = &example
+            .buffer
+            .as_ref()
+            .expect("`buffer` should be filled in in the context collection step")
+            .content;
+
+        // Extract updated (new) editable region from the model response
+        let new_editable_region = extract_last_codeblock(response);
+
+        // Reconstruct old editable region we sent to the model
+        let old_editable_region = Self::format_editable_region(example);
+        let old_editable_region = Self::extract_editable_region(&old_editable_region);
+        if !cursor_file.contains(&old_editable_region) {
+            panic!("Something's wrong: editable_region is not found in the cursor file")
+        }
+
+        // Apply editable region to a larger context and compute diff.
+        // This is needed to get a better context lines around the editable region
+        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
+        let diff = language::unified_diff(&cursor_file, &edited_file);
+
+        let diff = indoc::formatdoc! {"
+            --- a/{path}
+            +++ b/{path}
+            {diff}
+            ",
+            path = example.cursor_path.to_string_lossy(),
+            diff = diff,
+        };
+
+        diff
+    }
+}
+
+fn extract_last_codeblock(text: &str) -> String {
+    let mut last_block = None;
+    let mut search_start = 0;
+
+    while let Some(start) = text[search_start..].find("```") {
+        let start = start + search_start;
+        let bytes = text.as_bytes();
+        let mut backtick_end = start;
+
+        while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
+            backtick_end += 1;
+        }
+
+        let backtick_count = backtick_end - start;
+        let closing_backticks = "`".repeat(backtick_count);
+
+        while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
+            backtick_end += 1;
+        }
+
+        if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
+            let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
+            last_block = Some(code_block.to_string());
+            search_start = backtick_end + end_pos + backtick_count;
+        } else {
+            break;
+        }
+    }
+
+    last_block.unwrap_or_else(|| text.to_string())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_extract_last_code_block() {
+        let text = indoc::indoc! {"
+            Some thinking
+
+            ```
+            first block
+            ```
+
+            `````path='something' lines=1:2
+            last block
+            `````
+            "};
+        let last_block = extract_last_codeblock(text);
+        assert_eq!(last_block, "last block");
+    }
+
+    #[test]
+    fn test_extract_editable_region() {
+        let text = indoc::indoc! {"
+            some lines
+            are
+            here
+            <|editable_region_start|>
+            one
+            two three
+
+            <|editable_region_end|>
+            more
+            lines here
+            "};
+        let parsed = TeacherPrompt::extract_editable_region(text);
+        assert_eq!(
+            parsed,
+            indoc::indoc! {"
+            one
+            two three
+
+            "}
+        );
+    }
+}

crates/edit_prediction_cli/src/headless.rs 🔗

@@ -16,7 +16,7 @@ use std::sync::Arc;
 use util::ResultExt as _;
 
 /// Headless subset of `workspace::AppState`.
-pub struct ZetaCliAppState {
+pub struct EpAppState {
     pub languages: Arc<LanguageRegistry>,
     pub client: Arc<Client>,
     pub user_store: Entity<UserStore>,
@@ -25,7 +25,7 @@ pub struct ZetaCliAppState {
 }
 
 // TODO: dedupe with crates/eval/src/eval.rs
-pub fn init(cx: &mut App) -> ZetaCliAppState {
+pub fn init(cx: &mut App) -> EpAppState {
     let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
 
     let app_version = AppVersion::load(
@@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
     prompt_store::init(cx);
     terminal_view::init(cx);
 
-    ZetaCliAppState {
+    EpAppState {
         languages,
         client,
         user_store,

crates/edit_prediction_cli/src/load_project.rs 🔗

@@ -0,0 +1,320 @@
+use crate::{
+    example::{Example, ExampleBuffer, ExampleState},
+    headless::EpAppState,
+};
+use anyhow::{Result, anyhow};
+use collections::HashMap;
+use edit_prediction::EditPredictionStore;
+use edit_prediction::udiff::OpenedBuffers;
+use futures::{
+    AsyncWriteExt as _,
+    lock::{Mutex, OwnedMutexGuard},
+};
+use gpui::{AsyncApp, Entity};
+use language::{Anchor, Buffer, ToOffset, ToPoint};
+use project::buffer_store::BufferStoreEvent;
+use project::{Project, ProjectPath};
+use std::{
+    cell::RefCell,
+    fs,
+    path::{Path, PathBuf},
+    sync::Arc,
+};
+use util::{paths::PathStyle, rel_path::RelPath};
+use zeta_prompt::CURSOR_MARKER;
+
+pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
+    if example.state.is_some() {
+        return;
+    }
+
+    let project = setup_project(example, &app_state, &mut cx).await;
+    let buffer_store = project
+        .read_with(&cx, |project, _| project.buffer_store().clone())
+        .unwrap();
+
+    let ep_store = cx
+        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+        .unwrap();
+
+    cx.subscribe(&buffer_store, {
+        let project = project.clone();
+        move |_, event, cx| match event {
+            BufferStoreEvent::BufferAdded(buffer) => {
+                ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
+            }
+            _ => {}
+        }
+    })
+    .unwrap()
+    .detach();
+
+    let _open_buffers = apply_edit_history(example, &project, &mut cx)
+        .await
+        .unwrap();
+    let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
+    example.buffer = buffer
+        .read_with(&cx, |buffer, _cx| {
+            let cursor_point = cursor_position.to_point(&buffer);
+            Some(ExampleBuffer {
+                content: buffer.text(),
+                cursor_row: cursor_point.row,
+                cursor_column: cursor_point.column,
+                cursor_offset: cursor_position.to_offset(&buffer),
+            })
+        })
+        .unwrap();
+    example.state = Some(ExampleState {
+        buffer,
+        project,
+        cursor_position,
+        _open_buffers,
+    });
+}
+
+async fn cursor_position(
+    example: &Example,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> (Entity<Buffer>, Anchor) {
+    let worktree = project
+        .read_with(cx, |project, cx| {
+            project.visible_worktrees(cx).next().unwrap()
+        })
+        .unwrap();
+
+    let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
+        .unwrap()
+        .into_arc();
+    let cursor_buffer = project
+        .update(cx, |project, cx| {
+            project.open_buffer(
+                ProjectPath {
+                    worktree_id: worktree.read(cx).id(),
+                    path: cursor_path,
+                },
+                cx,
+            )
+        })
+        .unwrap()
+        .await
+        .unwrap();
+    let cursor_offset_within_excerpt = example
+        .cursor_position
+        .find(CURSOR_MARKER)
+        .ok_or_else(|| anyhow!("missing cursor marker"))
+        .unwrap();
+    let mut cursor_excerpt = example.cursor_position.clone();
+    cursor_excerpt.replace_range(
+        cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+        "",
+    );
+    let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+        let text = buffer.text();
+
+        let mut matches = text.match_indices(&cursor_excerpt);
+        let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
+            panic!(
+                "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
+            );
+        });
+        assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
+        excerpt_offset
+    }).unwrap();
+
+    let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+    let cursor_anchor = cursor_buffer
+        .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
+        .unwrap();
+
+    (cursor_buffer, cursor_anchor)
+}
+
+async fn setup_project(
+    example: &mut Example,
+    app_state: &Arc<EpAppState>,
+    cx: &mut AsyncApp,
+) -> Entity<Project> {
+    setup_worktree(example).await;
+
+    let project = cx
+        .update(|cx| {
+            Project::local(
+                app_state.client.clone(),
+                app_state.node_runtime.clone(),
+                app_state.user_store.clone(),
+                app_state.languages.clone(),
+                app_state.fs.clone(),
+                None,
+                cx,
+            )
+        })
+        .unwrap();
+
+    let worktree = project
+        .update(cx, |project, cx| {
+            project.create_worktree(&example.worktree_path(), true, cx)
+        })
+        .unwrap()
+        .await
+        .unwrap();
+    worktree
+        .read_with(cx, |worktree, _cx| {
+            worktree.as_local().unwrap().scan_complete()
+        })
+        .unwrap()
+        .await;
+    project
+}
+
+pub async fn setup_worktree(example: &Example) {
+    let repo_dir = example.repo_path();
+    let repo_lock = lock_repo(&repo_dir).await;
+
+    if !repo_dir.is_dir() {
+        fs::create_dir_all(&repo_dir).unwrap();
+        run_git(&repo_dir, &["init"]).await.unwrap();
+        run_git(
+            &repo_dir,
+            &["remote", "add", "origin", &example.repository_url],
+        )
+        .await
+        .unwrap();
+    }
+
+    // Resolve the example to a revision, fetching it if needed.
+    let revision = run_git(
+        &repo_dir,
+        &["rev-parse", &format!("{}^{{commit}}", example.revision)],
+    )
+    .await;
+    let revision = if let Ok(revision) = revision {
+        revision
+    } else {
+        if run_git(
+            &repo_dir,
+            &["fetch", "--depth", "1", "origin", &example.revision],
+        )
+        .await
+        .is_err()
+        {
+            run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
+        }
+        let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
+            .await
+            .unwrap();
+        if revision != example.revision {
+            run_git(&repo_dir, &["tag", &example.revision, &revision])
+                .await
+                .unwrap();
+        }
+        revision
+    };
+
+    // Create the worktree for this example if needed.
+    let worktree_path = example.worktree_path();
+    if worktree_path.is_dir() {
+        run_git(&worktree_path, &["clean", "--force", "-d"])
+            .await
+            .unwrap();
+        run_git(&worktree_path, &["reset", "--hard", "HEAD"])
+            .await
+            .unwrap();
+        run_git(&worktree_path, &["checkout", revision.as_str()])
+            .await
+            .unwrap();
+    } else {
+        let worktree_path_string = worktree_path.to_string_lossy();
+        run_git(
+            &repo_dir,
+            &["branch", "-f", &example.name, revision.as_str()],
+        )
+        .await
+        .unwrap();
+        run_git(
+            &repo_dir,
+            &[
+                "worktree",
+                "add",
+                "-f",
+                &worktree_path_string,
+                &example.name,
+            ],
+        )
+        .await
+        .unwrap();
+    }
+    drop(repo_lock);
+
+    // Apply the uncommitted diff for this example.
+    if !example.uncommitted_diff.is_empty() {
+        let mut apply_process = smol::process::Command::new("git")
+            .current_dir(&worktree_path)
+            .args(&["apply", "-"])
+            .stdin(std::process::Stdio::piped())
+            .spawn()
+            .unwrap();
+
+        let mut stdin = apply_process.stdin.take().unwrap();
+        stdin
+            .write_all(example.uncommitted_diff.as_bytes())
+            .await
+            .unwrap();
+        stdin.close().await.unwrap();
+        drop(stdin);
+
+        let apply_result = apply_process.output().await.unwrap();
+        if !apply_result.status.success() {
+            panic!(
+                "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+                apply_result.status,
+                String::from_utf8_lossy(&apply_result.stderr),
+                String::from_utf8_lossy(&apply_result.stdout),
+            );
+        }
+    }
+}
+
+async fn apply_edit_history(
+    example: &Example,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Result<OpenedBuffers> {
+    edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
+}
+
+thread_local! {
+    static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
+}
+
+#[must_use]
+pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
+    REPO_LOCKS
+        .with(|cell| {
+            cell.borrow_mut()
+                .entry(path.as_ref().to_path_buf())
+                .or_default()
+                .clone()
+        })
+        .lock_owned()
+        .await
+}
+
+async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
+    let output = smol::process::Command::new("git")
+        .current_dir(repo_path)
+        .args(args)
+        .output()
+        .await?;
+
+    anyhow::ensure!(
+        output.status.success(),
+        "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
+        args.join(" "),
+        repo_path.display(),
+        output.status,
+        String::from_utf8_lossy(&output.stderr),
+        String::from_utf8_lossy(&output.stdout),
+    );
+    Ok(String::from_utf8(output.stdout)?.trim().to_string())
+}

crates/edit_prediction_cli/src/main.rs 🔗

@@ -1,522 +1,196 @@
-mod evaluate;
+mod anthropic_client;
 mod example;
+mod format_prompt;
 mod headless;
+mod load_project;
 mod metrics;
 mod paths;
 mod predict;
-mod source_location;
-mod training;
-mod util;
+mod retrieve_context;
+mod score;
 
-use crate::{
-    evaluate::run_evaluate,
-    example::{ExampleFormat, NamedExample},
-    headless::ZetaCliAppState,
-    predict::run_predict,
-    source_location::SourceLocation,
-    training::{context::ContextType, distill::run_distill},
-    util::{open_buffer, open_buffer_with_language_server},
-};
-use ::util::{ResultExt, paths::PathStyle};
-use anyhow::{Result, anyhow};
-use clap::{Args, Parser, Subcommand, ValueEnum};
-use cloud_llm_client::predict_edits_v3;
-use edit_prediction::udiff::DiffLine;
-use edit_prediction_context::EditPredictionExcerptOptions;
-use gpui::{Application, AsyncApp, Entity, prelude::*};
-use language::{Bias, Buffer, BufferSnapshot, Point};
-use metrics::delta_chr_f;
-use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
+use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
+use edit_prediction::EditPredictionStore;
+use gpui::Application;
 use reqwest_client::ReqwestClient;
-use std::io::{self};
-use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
+use serde::{Deserialize, Serialize};
+use std::{path::PathBuf, sync::Arc};
+
+use crate::example::{read_examples, write_examples};
+use crate::format_prompt::run_format_prompt;
+use crate::load_project::run_load_project;
+use crate::predict::run_prediction;
+use crate::retrieve_context::run_context_retrieval;
+use crate::score::run_scoring;
 
 #[derive(Parser, Debug)]
-#[command(name = "zeta")]
-struct ZetaCliArgs {
+#[command(name = "ep")]
+struct EpArgs {
     #[arg(long, default_value_t = false)]
     printenv: bool,
+    #[clap(long, default_value_t = 10)]
+    max_parallelism: usize,
     #[command(subcommand)]
     command: Option<Command>,
+    #[clap(global = true)]
+    inputs: Vec<PathBuf>,
+    #[arg(long, short, global = true)]
+    output: Option<PathBuf>,
+    #[arg(long, short, global = true)]
+    in_place: bool,
 }
 
 #[derive(Subcommand, Debug)]
 enum Command {
-    Context(ContextArgs),
-    Predict(PredictArguments),
-    Eval(EvaluateArguments),
-    Distill(DistillArguments),
-    ConvertExample {
-        path: PathBuf,
-        #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
-        output_format: ExampleFormat,
-    },
-    Score {
-        golden_patch: PathBuf,
-        actual_patch: PathBuf,
-    },
+    /// Parse markdown examples and output a combined .jsonl file
+    ParseExample,
+    /// Create git worktrees for each example and load file contents
+    LoadBuffer,
+    /// Retrieve context for input examples.
+    Context,
+    /// Generate a prompt string for a specific model
+    FormatPrompt(FormatPromptArgs),
+    /// Runs edit prediction
+    Predict(PredictArgs),
+    /// Computes a score based on actual and expected patches
+    Score(PredictArgs),
+    /// Print aggregated scores
+    Eval(PredictArgs),
+    /// Remove git repositories and worktrees
     Clean,
 }
 
 #[derive(Debug, Args)]
-struct ContextArgs {
-    #[arg(long)]
-    provider: ContextProvider,
-    #[arg(long)]
-    worktree: PathBuf,
-    #[arg(long)]
-    cursor: SourceLocation,
-    #[arg(long)]
-    use_language_server: bool,
-    #[arg(long)]
-    edit_history: Option<FileOrStdin>,
-    #[clap(flatten)]
-    zeta2_args: Zeta2Args,
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
-enum ContextProvider {
-    Zeta1,
-    #[default]
-    Zeta2,
-}
-
-#[derive(Clone, Debug, Args)]
-struct Zeta2Args {
-    #[arg(long, default_value_t = 8192)]
-    max_prompt_bytes: usize,
-    #[arg(long, default_value_t = 2048)]
-    max_excerpt_bytes: usize,
-    #[arg(long, default_value_t = 1024)]
-    min_excerpt_bytes: usize,
-    #[arg(long, default_value_t = 0.66)]
-    target_before_cursor_over_total_bytes: f32,
-    #[arg(long, default_value_t = 1024)]
-    max_diagnostic_bytes: usize,
-    #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+struct FormatPromptArgs {
+    #[clap(long)]
     prompt_format: PromptFormat,
-    #[arg(long, value_enum, default_value_t = Default::default())]
-    output_format: OutputFormat,
-    #[arg(long, default_value_t = 42)]
-    file_indexing_parallelism: usize,
-    #[arg(long, default_value_t = false)]
-    disable_imports_gathering: bool,
-    #[arg(long, default_value_t = u8::MAX)]
-    max_retrieved_definitions: u8,
 }
 
-#[derive(Debug, Args)]
-pub struct PredictArguments {
-    #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
-    format: PredictionsOutputFormat,
-    example_path: PathBuf,
-    #[clap(flatten)]
-    options: PredictionOptions,
+#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
+enum PromptFormat {
+    Teacher,
+    Zeta2,
 }
 
 #[derive(Debug, Args)]
-pub struct DistillArguments {
-    split_commit_dataset: PathBuf,
-    #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
-    context_type: ContextType,
-    #[clap(long)]
-    batch: Option<String>,
-}
-
-#[derive(Clone, Debug, Args)]
-pub struct PredictionOptions {
-    #[clap(flatten)]
-    zeta2: Zeta2Args,
+struct PredictArgs {
     #[clap(long)]
     provider: PredictionProvider,
-    #[clap(long, value_enum, default_value_t = CacheMode::default())]
-    cache: CacheMode,
-}
-
-#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
-pub enum CacheMode {
-    /// Use cached LLM requests and responses, except when multiple repetitions are requested
-    #[default]
-    Auto,
-    /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
-    #[value(alias = "request")]
-    Requests,
-    /// Ignore existing cache entries for both LLM and search.
-    Skip,
-    /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
-    /// Useful for reproducing results and fixing bugs outside of search queries
-    Force,
-}
-
-impl CacheMode {
-    fn use_cached_llm_responses(&self) -> bool {
-        self.assert_not_auto();
-        matches!(self, CacheMode::Requests | CacheMode::Force)
-    }
-
-    fn use_cached_search_results(&self) -> bool {
-        self.assert_not_auto();
-        matches!(self, CacheMode::Force)
-    }
-
-    fn assert_not_auto(&self) {
-        assert_ne!(
-            *self,
-            CacheMode::Auto,
-            "Cache mode should not be auto at this point!"
-        );
-    }
-}
-
-#[derive(clap::ValueEnum, Debug, Clone)]
-pub enum PredictionsOutputFormat {
-    Json,
-    Md,
-    Diff,
+    #[clap(long, default_value_t = 1)]
+    repetitions: usize,
 }
 
-#[derive(Debug, Args)]
-pub struct EvaluateArguments {
-    example_paths: Vec<PathBuf>,
-    #[clap(flatten)]
-    options: PredictionOptions,
-    #[clap(short, long, default_value_t = 1, alias = "repeat")]
-    repetitions: u16,
-    #[arg(long)]
-    skip_prediction: bool,
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
+#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
 enum PredictionProvider {
+    Sweep,
+    Mercury,
     Zeta1,
-    #[default]
     Zeta2,
-    Sweep,
-}
-
-fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
-    edit_prediction::ZetaOptions {
-        context: EditPredictionExcerptOptions {
-            max_bytes: args.max_excerpt_bytes,
-            min_bytes: args.min_excerpt_bytes,
-            target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
-        },
-        max_prompt_bytes: args.max_prompt_bytes,
-        prompt_format: args.prompt_format.into(),
-    }
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
-enum PromptFormat {
-    OnlySnippets,
-    #[default]
-    OldTextNewText,
-    Minimal,
-    MinimalQwen,
-    SeedCoder1120,
+    Teacher,
 }
 
-impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
-    fn into(self) -> predict_edits_v3::PromptFormat {
-        match self {
-            Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
-            Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
-            Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
-            Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
-            Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
+impl EpArgs {
+    fn output_path(&self) -> Option<PathBuf> {
+        if self.in_place {
+            if self.inputs.len() == 1 {
+                self.inputs.first().cloned()
+            } else {
+                panic!("--in-place requires exactly one input file")
+            }
+        } else {
+            self.output.clone()
         }
     }
 }
 
-#[derive(clap::ValueEnum, Default, Debug, Clone)]
-enum OutputFormat {
-    #[default]
-    Prompt,
-    Request,
-    Full,
-}
-
-#[derive(Debug, Clone)]
-enum FileOrStdin {
-    File(PathBuf),
-    Stdin,
-}
+fn main() {
+    zlog::init();
+    zlog::init_output_stderr();
+    let args = EpArgs::parse();
 
-impl FileOrStdin {
-    async fn read_to_string(&self) -> Result<String, std::io::Error> {
-        match self {
-            FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
-            FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
-        }
+    if args.printenv {
+        ::util::shell_env::print_env();
+        return;
     }
-}
-
-impl FromStr for FileOrStdin {
-    type Err = <PathBuf as FromStr>::Err;
 
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "-" => Ok(Self::Stdin),
-            _ => Ok(Self::File(PathBuf::from_str(s)?)),
+    let output = args.output_path();
+    let command = match args.command {
+        Some(cmd) => cmd,
+        None => {
+            EpArgs::command().print_help().unwrap();
+            return;
         }
-    }
-}
-
-struct LoadedContext {
-    full_path_str: String,
-    snapshot: BufferSnapshot,
-    clipped_cursor: Point,
-    worktree: Entity<Worktree>,
-    project: Entity<Project>,
-    buffer: Entity<Buffer>,
-    lsp_open_handle: Option<OpenLspBufferHandle>,
-}
-
-async fn load_context(
-    args: &ContextArgs,
-    app_state: &Arc<ZetaCliAppState>,
-    cx: &mut AsyncApp,
-) -> Result<LoadedContext> {
-    let ContextArgs {
-        worktree: worktree_path,
-        cursor,
-        use_language_server,
-        ..
-    } = args;
-
-    let worktree_path = worktree_path.canonicalize()?;
-
-    let project = cx.update(|cx| {
-        Project::local(
-            app_state.client.clone(),
-            app_state.node_runtime.clone(),
-            app_state.user_store.clone(),
-            app_state.languages.clone(),
-            app_state.fs.clone(),
-            None,
-            cx,
-        )
-    })?;
-
-    let worktree = project
-        .update(cx, |project, cx| {
-            project.create_worktree(&worktree_path, true, cx)
-        })?
-        .await?;
-
-    let mut ready_languages = HashSet::default();
-    let (lsp_open_handle, buffer) = if *use_language_server {
-        let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
-            project.clone(),
-            worktree.clone(),
-            cursor.path.clone(),
-            &mut ready_languages,
-            cx,
-        )
-        .await?;
-        (Some(lsp_open_handle), buffer)
-    } else {
-        let buffer =
-            open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
-        (None, buffer)
     };
 
-    let full_path_str = worktree
-        .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
-        .display(PathStyle::local())
-        .to_string();
-
-    let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
-    let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
-    if clipped_cursor != cursor.point {
-        let max_row = snapshot.max_point().row;
-        if cursor.point.row < max_row {
-            return Err(anyhow!(
-                "Cursor position {:?} is out of bounds (line length is {})",
-                cursor.point,
-                snapshot.line_len(cursor.point.row)
-            ));
-        } else {
-            return Err(anyhow!(
-                "Cursor position {:?} is out of bounds (max row is {})",
-                cursor.point,
-                max_row
-            ));
+    match &command {
+        Command::Clean => {
+            std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
+            return;
         }
+        _ => {}
     }
 
-    Ok(LoadedContext {
-        full_path_str,
-        snapshot,
-        clipped_cursor,
-        worktree,
-        project,
-        buffer,
-        lsp_open_handle,
-    })
-}
-
-async fn zeta2_context(
-    args: ContextArgs,
-    app_state: &Arc<ZetaCliAppState>,
-    cx: &mut AsyncApp,
-) -> Result<String> {
-    let LoadedContext {
-        worktree,
-        project,
-        buffer,
-        clipped_cursor,
-        lsp_open_handle: _handle,
-        ..
-    } = load_context(&args, app_state, cx).await?;
-
-    // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
-    // the whole worktree.
-    worktree
-        .read_with(cx, |worktree, _cx| {
-            worktree.as_local().unwrap().scan_complete()
-        })?
-        .await;
-    let output = cx
-        .update(|cx| {
-            let store = cx.new(|cx| {
-                edit_prediction::EditPredictionStore::new(
-                    app_state.client.clone(),
-                    app_state.user_store.clone(),
-                    cx,
-                )
-            });
-            store.update(cx, |store, cx| {
-                store.set_options(zeta2_args_to_options(&args.zeta2_args));
-                store.register_buffer(&buffer, &project, cx);
-            });
-            cx.spawn(async move |cx| {
-                let updates_rx = store.update(cx, |store, cx| {
-                    let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
-                    store.set_use_context(true);
-                    store.refresh_context(&project, &buffer, cursor, cx);
-                    store.project_context_updates(&project).unwrap()
-                })?;
-
-                updates_rx.recv().await.ok();
-
-                let context = store.update(cx, |store, cx| {
-                    store.context_for_project(&project, cx).to_vec()
-                })?;
-
-                anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
-            })
-        })?
-        .await?;
-
-    Ok(output)
-}
-
-async fn zeta1_context(
-    args: ContextArgs,
-    app_state: &Arc<ZetaCliAppState>,
-    cx: &mut AsyncApp,
-) -> Result<edit_prediction::zeta1::GatherContextOutput> {
-    let LoadedContext {
-        full_path_str,
-        snapshot,
-        clipped_cursor,
-        ..
-    } = load_context(&args, app_state, cx).await?;
-
-    let events = match args.edit_history {
-        Some(events) => events.read_to_string().await?,
-        None => String::new(),
-    };
-
-    let prompt_for_events = move || (events, 0);
-    cx.update(|cx| {
-        edit_prediction::zeta1::gather_context(
-            full_path_str,
-            &snapshot,
-            clipped_cursor,
-            prompt_for_events,
-            cloud_llm_client::PredictEditsRequestTrigger::Cli,
-            cx,
-        )
-    })?
-    .await
-}
-
-fn main() {
-    zlog::init();
-    zlog::init_output_stderr();
-    let args = ZetaCliArgs::parse();
+    let mut examples = read_examples(&args.inputs);
     let http_client = Arc::new(ReqwestClient::new());
     let app = Application::headless().with_http_client(http_client);
 
     app.run(move |cx| {
         let app_state = Arc::new(headless::init(cx));
+        EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
+
         cx.spawn(async move |cx| {
-            match args.command {
-                None => {
-                    if args.printenv {
-                        ::util::shell_env::print_env();
-                    } else {
-                        panic!("Expected a command");
-                    }
-                }
-                Some(Command::Context(context_args)) => {
-                    let result = match context_args.provider {
-                        ContextProvider::Zeta1 => {
-                            let context =
-                                zeta1_context(context_args, &app_state, cx).await.unwrap();
-                            serde_json::to_string_pretty(&context.body).unwrap()
-                        }
-                        ContextProvider::Zeta2 => {
-                            zeta2_context(context_args, &app_state, cx).await.unwrap()
+            match &command {
+                Command::Predict(args) => predict::sync_batches(&args.provider).await,
+                _ => (),
+            };
+
+            for data in examples.chunks_mut(args.max_parallelism) {
+                let mut futures = Vec::new();
+                for example in data.iter_mut() {
+                    let cx = cx.clone();
+                    let app_state = app_state.clone();
+                    futures.push(async {
+                        match &command {
+                            Command::ParseExample => {}
+                            Command::LoadBuffer => {
+                                run_load_project(example, app_state.clone(), cx).await;
+                            }
+                            Command::Context => {
+                                run_context_retrieval(example, app_state, cx).await;
+                            }
+                            Command::FormatPrompt(args) => {
+                                run_format_prompt(example, args.prompt_format, app_state, cx).await;
+                            }
+                            Command::Predict(args) => {
+                                run_prediction(
+                                    example,
+                                    Some(args.provider),
+                                    args.repetitions,
+                                    app_state.clone(),
+                                    cx,
+                                )
+                                .await;
+                            }
+                            Command::Score(args) | Command::Eval(args) => {
+                                run_scoring(example, &args, app_state, cx).await;
+                            }
+                            Command::Clean => {
+                                unreachable!()
+                            }
                         }
-                    };
-                    println!("{}", result);
-                }
-                Some(Command::Predict(arguments)) => {
-                    run_predict(arguments, &app_state, cx).await;
-                }
-                Some(Command::Eval(arguments)) => {
-                    run_evaluate(arguments, &app_state, cx).await;
+                    });
                 }
-                Some(Command::Distill(arguments)) => {
-                    let _guard = cx
-                        .update(|cx| gpui_tokio::Tokio::handle(cx))
-                        .unwrap()
-                        .enter();
-                    run_distill(arguments).await.log_err();
-                }
-                Some(Command::ConvertExample {
-                    path,
-                    output_format,
-                }) => {
-                    let example = NamedExample::load(path).unwrap();
-                    example.write(output_format, io::stdout()).unwrap();
-                }
-                Some(Command::Score {
-                    golden_patch,
-                    actual_patch,
-                }) => {
-                    let golden_content = std::fs::read_to_string(golden_patch).unwrap();
-                    let actual_content = std::fs::read_to_string(actual_patch).unwrap();
-
-                    let golden_diff: Vec<DiffLine> = golden_content
-                        .lines()
-                        .map(|line| DiffLine::parse(line))
-                        .collect();
+                futures::future::join_all(futures).await;
+            }
 
-                    let actual_diff: Vec<DiffLine> = actual_content
-                        .lines()
-                        .map(|line| DiffLine::parse(line))
-                        .collect();
+            if args.output.is_some() || !matches!(command, Command::Eval(_)) {
+                write_examples(&examples, output.as_ref());
+            }
 
-                    let score = delta_chr_f(&golden_diff, &actual_diff);
-                    println!("{:.2}", score);
-                }
-                Some(Command::Clean) => {
-                    std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
-                }
+            match &command {
+                Command::Predict(args) => predict::sync_batches(&args.provider).await,
+                Command::Eval(_) => score::print_report(&examples),
+                _ => (),
             };
 
             let _ = cx.update(|cx| cx.quit());

crates/edit_prediction_cli/src/metrics.rs 🔗

@@ -1,30 +1,34 @@
 use collections::{HashMap, HashSet};
 use edit_prediction::udiff::DiffLine;
+use serde::{Deserialize, Serialize};
 
 type Counts = HashMap<String, usize>;
 type CountsDelta = HashMap<String, isize>;
 
-#[derive(Default, Debug, Clone)]
-pub struct Scores {
+#[derive(Default, Debug, Clone, Serialize, Deserialize)]
+pub struct ClassificationMetrics {
     pub true_positives: usize,
     pub false_positives: usize,
     pub false_negatives: usize,
 }
 
-impl Scores {
-    pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
+impl ClassificationMetrics {
+    pub fn from_sets(
+        expected: &HashSet<String>,
+        actual: &HashSet<String>,
+    ) -> ClassificationMetrics {
         let true_positives = expected.intersection(actual).count();
         let false_positives = actual.difference(expected).count();
         let false_negatives = expected.difference(actual).count();
 
-        Scores {
+        ClassificationMetrics {
             true_positives,
             false_positives,
             false_negatives,
         }
     }
 
-    pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
+    pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
         let mut true_positives = 0;
         let mut false_positives = 0;
         let mut false_negatives = 0;
@@ -45,32 +49,16 @@ impl Scores {
             }
         }
 
-        Scores {
+        ClassificationMetrics {
             true_positives,
             false_positives,
             false_negatives,
         }
     }
 
-    pub fn to_markdown(&self) -> String {
-        format!(
-            "
-Precision       : {:.4}
-Recall          : {:.4}
-F1 Score        : {:.4}
-True Positives  : {}
-False Positives : {}
-False Negatives : {}",
-            self.precision(),
-            self.recall(),
-            self.f1_score(),
-            self.true_positives,
-            self.false_positives,
-            self.false_negatives
-        )
-    }
-
-    pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
+    pub fn aggregate<'a>(
+        scores: impl Iterator<Item = &'a ClassificationMetrics>,
+    ) -> ClassificationMetrics {
         let mut true_positives = 0;
         let mut false_positives = 0;
         let mut false_negatives = 0;
@@ -81,7 +69,7 @@ False Negatives : {}",
             false_negatives += score.false_negatives;
         }
 
-        Scores {
+        ClassificationMetrics {
             true_positives,
             false_positives,
             false_negatives,
@@ -115,7 +103,10 @@ False Negatives : {}",
     }
 }
 
-pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
+pub fn line_match_score(
+    expected_patch: &[DiffLine],
+    actual_patch: &[DiffLine],
+) -> ClassificationMetrics {
     let expected_change_lines = expected_patch
         .iter()
         .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
         .map(|line| line.to_string())
         .collect();
 
-    Scores::from_sets(&expected_change_lines, &actual_change_lines)
+    ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
 }
 
 enum ChrfWhitespace {
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
         let expected_counts = ngram_delta_to_counts(&expected_delta);
         let actual_counts = ngram_delta_to_counts(&actual_delta);
 
-        let score = Scores::from_counts(&expected_counts, &actual_counts);
+        let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
         total_precision += score.precision();
         total_recall += score.recall();
     }

crates/edit_prediction_cli/src/paths.rs 🔗

@@ -1,57 +1,25 @@
-use std::{env, path::PathBuf, sync::LazyLock};
+use std::{
+    path::{Path, PathBuf},
+    sync::LazyLock,
+};
 
-pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
-    LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
-pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
-pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
-pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
+pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
+    let dir = dirs::home_dir().unwrap().join(".zed_ep");
+    ensure_dir(&dir)
+});
+pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
+pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
+pub static WORKTREES_DIR: LazyLock<PathBuf> =
+    LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
 pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
-    TARGET_ZETA_DIR
+    DATA_DIR
         .join("runs")
         .join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
 });
-pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
-    LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
-
-pub fn print_run_data_dir(deep: bool, use_color: bool) {
-    println!("\n## Run Data\n");
-    let mut files = Vec::new();
-
-    let current_dir = std::env::current_dir().unwrap();
-    for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
-        let file = file.unwrap();
-        if file.file_type().unwrap().is_dir() && deep {
-            for file in std::fs::read_dir(file.path()).unwrap() {
-                let path = file.unwrap().path();
-                let path = path.strip_prefix(&current_dir).unwrap_or(&path);
-                files.push(format!(
-                    "- {}/{}{}{}",
-                    path.parent().unwrap().display(),
-                    if use_color { "\x1b[34m" } else { "" },
-                    path.file_name().unwrap().display(),
-                    if use_color { "\x1b[0m" } else { "" },
-                ));
-            }
-        } else {
-            let path = file.path();
-            let path = path.strip_prefix(&current_dir).unwrap_or(&path);
-            files.push(format!(
-                "- {}/{}{}{}",
-                path.parent().unwrap().display(),
-                if use_color { "\x1b[34m" } else { "" },
-                path.file_name().unwrap().display(),
-                if use_color { "\x1b[0m" } else { "" }
-            ));
-        }
-    }
-    files.sort();
-
-    for file in files {
-        println!("{}", file);
-    }
+pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
+pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
 
-    println!(
-        "\n💡 Tip of the day: {} always points to the latest run\n",
-        LATEST_EXAMPLE_RUN_DIR.display()
-    );
+fn ensure_dir(path: &Path) -> PathBuf {
+    std::fs::create_dir_all(path).expect("Failed to create directory");
+    path.to_path_buf()
 }

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -1,374 +1,271 @@
-use crate::example::{ActualExcerpt, NamedExample};
-use crate::headless::ZetaCliAppState;
-use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
 use crate::{
-    CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
+    PredictionProvider, PromptFormat,
+    anthropic_client::AnthropicClient,
+    example::{Example, ExamplePrediction},
+    format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
+    headless::EpAppState,
+    load_project::run_load_project,
+    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
+    retrieve_context::run_context_retrieval,
+};
+use edit_prediction::{DebugEvent, EditPredictionStore};
+use futures::{FutureExt as _, StreamExt as _, future::Shared};
+use gpui::{AppContext as _, AsyncApp, Task};
+use std::{
+    fs,
+    sync::{
+        Arc, Mutex, OnceLock,
+        atomic::{AtomicUsize, Ordering::SeqCst},
+    },
 };
-use ::serde::Serialize;
-use anyhow::{Context, Result, anyhow};
-use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
-use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
-use futures::StreamExt as _;
-use gpui::{AppContext, AsyncApp, Entity};
-use project::Project;
-use project::buffer_store::BufferStoreEvent;
-use serde::Deserialize;
-use std::fs;
-use std::io::{IsTerminal, Write};
-use std::path::PathBuf;
-use std::sync::Arc;
-use std::sync::Mutex;
-use std::time::{Duration, Instant};
 
-pub async fn run_predict(
-    args: PredictArguments,
-    app_state: &Arc<ZetaCliAppState>,
-    cx: &mut AsyncApp,
+pub async fn run_prediction(
+    example: &mut Example,
+    provider: Option<PredictionProvider>,
+    repetition_count: usize,
+    app_state: Arc<EpAppState>,
+    mut cx: AsyncApp,
 ) {
-    let example = NamedExample::load(args.example_path).unwrap();
-    let project = example.setup_project(app_state, cx).await.unwrap();
-    let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
-    let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
-    let result = perform_predict(example, project, store, None, args.options, cx)
-        .await
-        .unwrap();
-    result.write(args.format, std::io::stdout()).unwrap();
-
-    print_run_data_dir(true, std::io::stdout().is_terminal());
-}
-
-pub fn setup_store(
-    provider: PredictionProvider,
-    project: &Entity<Project>,
-    app_state: &Arc<ZetaCliAppState>,
-    cx: &mut AsyncApp,
-) -> Result<Entity<EditPredictionStore>> {
-    let store = cx.new(|cx| {
-        edit_prediction::EditPredictionStore::new(
-            app_state.client.clone(),
-            app_state.user_store.clone(),
-            cx,
-        )
-    })?;
+    if !example.predictions.is_empty() {
+        return;
+    }
 
-    store.update(cx, |store, _cx| {
-        let model = match provider {
-            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
-            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
-            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
-        };
-        store.set_edit_prediction_model(model);
-    })?;
+    run_load_project(example, app_state.clone(), cx.clone()).await;
+    run_context_retrieval(example, app_state.clone(), cx.clone()).await;
 
-    let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+    let provider = provider.unwrap();
 
-    cx.subscribe(&buffer_store, {
-        let project = project.clone();
-        let store = store.clone();
-        move |_, event, cx| match event {
-            BufferStoreEvent::BufferAdded(buffer) => {
-                store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
-            }
-            _ => {}
+    if matches!(provider, PredictionProvider::Teacher) {
+        if example.prompt.is_none() {
+            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
         }
-    })?
-    .detach();
 
-    anyhow::Ok(store)
-}
-
-pub async fn perform_predict(
-    example: NamedExample,
-    project: Entity<Project>,
-    store: Entity<EditPredictionStore>,
-    repetition_ix: Option<u16>,
-    options: PredictionOptions,
-    cx: &mut AsyncApp,
-) -> Result<PredictionDetails> {
-    let mut cache_mode = options.cache;
-    if repetition_ix.is_some() {
-        if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
-            panic!("Repetitions are not supported in Auto cache mode");
-        } else {
-            cache_mode = CacheMode::Skip;
-        }
-    } else if cache_mode == CacheMode::Auto {
-        cache_mode = CacheMode::Requests;
+        let batched = true;
+        return predict_anthropic(example, repetition_count, batched).await;
     }
 
-    let mut example_run_dir = RUN_DIR.join(&example.file_name());
-    if let Some(repetition_ix) = repetition_ix {
-        example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
-    }
-    fs::create_dir_all(&example_run_dir)?;
-    if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
-        fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
+    if matches!(
+        provider,
+        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
+    ) {
+        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
+        AUTHENTICATED
+            .get_or_init(|| {
+                let client = app_state.client.clone();
+                cx.spawn(async move |cx| {
+                    client
+                        .sign_in_with_optional_connect(true, cx)
+                        .await
+                        .unwrap();
+                })
+                .shared()
+            })
+            .clone()
+            .await;
     }
 
-    #[cfg(unix)]
-    std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
-        .context("creating latest link")?;
-
-    #[cfg(windows)]
-    std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
-        .context("creating latest link")?;
-
-    store.update(cx, |store, _cx| {
-        store.with_eval_cache(Arc::new(RunCache {
-            example_run_dir: example_run_dir.clone(),
-            cache_mode,
-        }));
-    })?;
-
-    let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
-
-    let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
-
-    let prompt_format = options.zeta2.prompt_format;
-
-    store.update(cx, |store, _cx| {
-        let mut options = store.options().clone();
-        options.prompt_format = prompt_format.into();
-        store.set_options(options);
-    })?;
+    let ep_store = cx
+        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+        .unwrap();
 
-    let mut debug_task = gpui::Task::ready(Ok(()));
+    ep_store
+        .update(&mut cx, |store, _cx| {
+            let model = match provider {
+                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
+                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
+                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
+                PredictionProvider::Teacher => unreachable!(),
+            };
+            store.set_edit_prediction_model(model);
+        })
+        .unwrap();
+    let state = example.state.as_ref().unwrap();
+    let run_dir = RUN_DIR.join(&example.name);
 
-    if options.provider == crate::PredictionProvider::Zeta2 {
-        let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
+    let updated_example = Arc::new(Mutex::new(example.clone()));
+    let current_run_ix = Arc::new(AtomicUsize::new(0));
 
-        debug_task = cx.background_spawn({
-            let result = result.clone();
-            async move {
-                let mut start_time = None;
-                let mut retrieval_finished_at = None;
-                while let Some(event) = debug_rx.next().await {
-                    match event {
-                        edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
-                            start_time = Some(info.timestamp);
-                            fs::write(
-                                example_run_dir.join("search_prompt.md"),
-                                &info.search_prompt,
-                            )?;
+    let mut debug_rx = ep_store
+        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
+        .unwrap();
+    let debug_task = cx.background_spawn({
+        let updated_example = updated_example.clone();
+        let current_run_ix = current_run_ix.clone();
+        let run_dir = run_dir.clone();
+        async move {
+            while let Some(event) = debug_rx.next().await {
+                let run_ix = current_run_ix.load(SeqCst);
+                let mut updated_example = updated_example.lock().unwrap();
+
+                let run_dir = if repetition_count > 1 {
+                    run_dir.join(format!("{:03}", run_ix))
+                } else {
+                    run_dir.clone()
+                };
+
+                match event {
+                    DebugEvent::EditPredictionStarted(request) => {
+                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
+
+                        if let Some(prompt) = request.prompt {
+                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                         }
-                        edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
-                            retrieval_finished_at = Some(info.timestamp);
-                            for (key, value) in &info.metadata {
-                                if *key == "search_queries" {
-                                    fs::write(
-                                        example_run_dir.join("search_queries.json"),
-                                        value.as_bytes(),
-                                    )?;
-                                }
-                            }
+                    }
+                    DebugEvent::EditPredictionFinished(request) => {
+                        assert_eq!(updated_example.predictions.len(), run_ix + 1);
+
+                        if let Some(output) = request.model_output {
+                            fs::write(run_dir.join("prediction_response.md"), &output)?;
+                            updated_example
+                                .predictions
+                                .last_mut()
+                                .unwrap()
+                                .actual_output = output;
                         }
-                        edit_prediction::DebugEvent::EditPredictionRequested(request) => {
-                            let prediction_started_at = Instant::now();
-                            start_time.get_or_insert(prediction_started_at);
-                            let prompt = request.local_prompt.unwrap_or_default();
-                            fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
-
-                            {
-                                let mut result = result.lock().unwrap();
-                                result.prompt_len = prompt.chars().count();
-
-                                for included_file in request.inputs.included_files {
-                                    let insertions =
-                                        vec![(request.inputs.cursor_point, CURSOR_MARKER)];
-                                    result.excerpts.extend(included_file.excerpts.iter().map(
-                                        |excerpt| ActualExcerpt {
-                                            path: included_file.path.components().skip(1).collect(),
-                                            text: String::from(excerpt.text.as_ref()),
-                                        },
-                                    ));
-                                    write_codeblock(
-                                        &included_file.path,
-                                        included_file.excerpts.iter(),
-                                        if included_file.path == request.inputs.cursor_path {
-                                            &insertions
-                                        } else {
-                                            &[]
-                                        },
-                                        included_file.max_row,
-                                        false,
-                                        &mut result.excerpts_text,
-                                    );
-                                }
-                            }
-
-                            let response =
-                                request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
-                            let response =
-                                edit_prediction::open_ai_response::text_from_response(response)
-                                    .unwrap_or_default();
-                            let prediction_finished_at = Instant::now();
-                            fs::write(example_run_dir.join("prediction_response.md"), &response)?;
-
-                            let mut result = result.lock().unwrap();
-                            result.generated_len = response.chars().count();
-                            result.retrieval_time =
-                                retrieval_finished_at.unwrap() - start_time.unwrap();
-                            result.prediction_time = prediction_finished_at - prediction_started_at;
-                            result.total_time = prediction_finished_at - start_time.unwrap();
-
+                        if run_ix >= repetition_count {
                             break;
                         }
                     }
+                    _ => {}
                 }
-                anyhow::Ok(())
             }
-        });
-
-        store.update(cx, |store, cx| {
-            store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
-        })?;
-    }
-
-    let prediction = store
-        .update(cx, |store, cx| {
-            store.request_prediction(
-                &project,
-                &cursor_buffer,
-                cursor_anchor,
-                cloud_llm_client::PredictEditsRequestTrigger::Cli,
-                cx,
-            )
-        })?
-        .await?;
-
-    debug_task.await?;
-
-    let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
-
-    result.diff = prediction
-        .and_then(|prediction| {
-            let prediction = prediction.prediction.ok()?;
-            prediction.edit_preview.as_unified_diff(&prediction.edits)
-        })
-        .unwrap_or_default();
-
-    anyhow::Ok(result)
-}
-
-struct RunCache {
-    cache_mode: CacheMode,
-    example_run_dir: PathBuf,
-}
+            anyhow::Ok(())
+        }
+    });
 
-impl RunCache {
-    fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
-        CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
-    }
+    for ix in 0..repetition_count {
+        current_run_ix.store(ix, SeqCst);
+        let run_dir = if repetition_count > 1 {
+            run_dir.join(format!("{:03}", ix))
+        } else {
+            run_dir.clone()
+        };
 
-    fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
-        CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
+        fs::create_dir_all(&run_dir).unwrap();
+        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
+            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
+        }
+        #[cfg(unix)]
+        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+        #[cfg(windows)]
+        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+
+        updated_example
+            .lock()
+            .unwrap()
+            .predictions
+            .push(ExamplePrediction {
+                actual_patch: String::new(),
+                actual_output: String::new(),
+                provider,
+            });
+
+        let prediction = ep_store
+            .update(&mut cx, |store, cx| {
+                store.request_prediction(
+                    &state.project,
+                    &state.buffer,
+                    state.cursor_position,
+                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
+                    cx,
+                )
+            })
+            .unwrap()
+            .await
+            .unwrap();
+
+        updated_example
+            .lock()
+            .unwrap()
+            .predictions
+            .last_mut()
+            .unwrap()
+            .actual_patch = prediction
+            .and_then(|prediction| {
+                let prediction = prediction.prediction.ok()?;
+                prediction.edit_preview.as_unified_diff(&prediction.edits)
+            })
+            .unwrap_or_default();
     }
 
-    fn link_to_run(&self, key: &EvalCacheKey) {
-        let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
-        fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
+    ep_store
+        .update(&mut cx, |store, _| {
+            store.remove_project(&state.project);
+        })
+        .unwrap();
+    debug_task.await.unwrap();
 
-        let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
-        fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
-    }
+    *example = Arc::into_inner(updated_example)
+        .unwrap()
+        .into_inner()
+        .unwrap();
 }
 
-impl EvalCache for RunCache {
-    fn read(&self, key: EvalCacheKey) -> Option<String> {
-        let path = RunCache::output_cache_path(&key);
-
-        if path.exists() {
-            let use_cache = match key.0 {
-                EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
-                EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
-                    self.cache_mode.use_cached_llm_responses()
-                }
-            };
-            if use_cache {
-                log::info!("Using cache entry: {}", path.display());
-                self.link_to_run(&key);
-                Some(fs::read_to_string(path).unwrap())
-            } else {
-                log::trace!("Skipping cached entry: {}", path.display());
-                None
-            }
-        } else if matches!(self.cache_mode, CacheMode::Force) {
-            panic!(
-                "No cached entry found for {:?}. Run without `--cache force` at least once.",
-                key.0
-            );
-        } else {
-            None
-        }
-    }
-
-    fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
-        fs::create_dir_all(&*CACHE_DIR).unwrap();
+async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
+    let llm_model_name = "claude-sonnet-4-5";
+    let max_tokens = 16384;
+    let llm_client = if batched {
+        AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
+    } else {
+        AnthropicClient::plain()
+    };
+    let llm_client = llm_client.expect("Failed to create LLM client");
+
+    let prompt = example
+        .prompt
+        .as_ref()
+        .unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
+
+    let messages = vec![anthropic::Message {
+        role: anthropic::Role::User,
+        content: vec![anthropic::RequestContent::Text {
+            text: prompt.input.clone(),
+            cache_control: None,
+        }],
+    }];
+
+    let Some(response) = llm_client
+        .generate(llm_model_name, max_tokens, messages)
+        .await
+        .unwrap()
+    else {
+        // Request stashed for batched processing
+        return;
+    };
+
+    let actual_output = response
+        .content
+        .into_iter()
+        .filter_map(|content| match content {
+            anthropic::ResponseContent::Text { text } => Some(text),
+            _ => None,
+        })
+        .collect::<Vec<String>>()
+        .join("\n");
 
-        let input_path = RunCache::input_cache_path(&key);
-        fs::write(&input_path, input).unwrap();
+    let actual_patch = TeacherPrompt::parse(example, &actual_output);
 
-        let output_path = RunCache::output_cache_path(&key);
-        log::trace!("Writing cache entry: {}", output_path.display());
-        fs::write(&output_path, output).unwrap();
+    let prediction = ExamplePrediction {
+        actual_patch,
+        actual_output,
+        provider: PredictionProvider::Teacher,
+    };
 
-        self.link_to_run(&key);
-    }
+    example.predictions.push(prediction);
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct PredictionDetails {
-    pub diff: String,
-    pub excerpts: Vec<ActualExcerpt>,
-    pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
-    pub retrieval_time: Duration,
-    pub prediction_time: Duration,
-    pub total_time: Duration,
-    pub run_example_dir: PathBuf,
-    pub prompt_len: usize,
-    pub generated_len: usize,
-}
-
-impl PredictionDetails {
-    pub fn new(run_example_dir: PathBuf) -> Self {
-        Self {
-            diff: Default::default(),
-            excerpts: Default::default(),
-            excerpts_text: Default::default(),
-            retrieval_time: Default::default(),
-            prediction_time: Default::default(),
-            total_time: Default::default(),
-            run_example_dir,
-            prompt_len: 0,
-            generated_len: 0,
+pub async fn sync_batches(provider: &PredictionProvider) {
+    match provider {
+        PredictionProvider::Teacher => {
+            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
+            let llm_client =
+                AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
+            llm_client
+                .sync_batches()
+                .await
+                .expect("Failed to sync batches");
         }
-    }
-
-    pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
-        let formatted = match format {
-            PredictionsOutputFormat::Md => self.to_markdown(),
-            PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
-            PredictionsOutputFormat::Diff => self.diff.clone(),
-        };
-
-        Ok(out.write_all(formatted.as_bytes())?)
-    }
-
-    pub fn to_markdown(&self) -> String {
-        format!(
-            "## Excerpts\n\n\
-            {}\n\n\
-            ## Prediction\n\n\
-            {}\n\n\
-            ## Time\n\n\
-            Retrieval: {}ms\n\
-            Prediction: {}ms\n\n\
-            Total: {}ms\n",
-            self.excerpts_text,
-            self.diff,
-            self.retrieval_time.as_millis(),
-            self.prediction_time.as_millis(),
-            self.total_time.as_millis(),
-        )
+        _ => (),
     }
 }

crates/edit_prediction_cli/src/util.rs → crates/edit_prediction_cli/src/retrieve_context.rs 🔗

@@ -1,106 +1,136 @@
-use anyhow::{Result, anyhow};
-use futures::channel::mpsc;
-use futures::{FutureExt as _, StreamExt as _};
+use crate::{
+    example::{Example, ExampleContext},
+    headless::EpAppState,
+    load_project::run_load_project,
+};
+use anyhow::Result;
+use collections::HashSet;
+use edit_prediction::{DebugEvent, EditPredictionStore};
+use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
 use gpui::{AsyncApp, Entity, Task};
-use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
-use project::lsp_store::OpenLspBufferHandle;
-use project::{Project, ProjectPath, Worktree};
-use std::collections::HashSet;
-use std::sync::Arc;
-use std::time::Duration;
-use util::rel_path::RelPath;
-
-pub fn open_buffer(
-    project: Entity<Project>,
-    worktree: Entity<Worktree>,
-    path: Arc<RelPath>,
-    cx: &AsyncApp,
-) -> Task<Result<Entity<Buffer>>> {
-    cx.spawn(async move |cx| {
-        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
-            worktree_id: worktree.id(),
-            path,
-        })?;
-
-        let buffer = project
-            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
-            .await?;
-
-        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
-        while *parse_status.borrow() != ParseStatus::Idle {
-            parse_status.changed().await?;
+use language::{Buffer, LanguageNotFound};
+use project::Project;
+use std::{sync::Arc, time::Duration};
+
+pub async fn run_context_retrieval(
+    example: &mut Example,
+    app_state: Arc<EpAppState>,
+    mut cx: AsyncApp,
+) {
+    if example.context.is_some() {
+        return;
+    }
+
+    run_load_project(example, app_state.clone(), cx.clone()).await;
+
+    let state = example.state.as_ref().unwrap();
+    let project = state.project.clone();
+
+    let _lsp_handle = project
+        .update(&mut cx, |project, cx| {
+            project.register_buffer_with_language_servers(&state.buffer, cx)
+        })
+        .unwrap();
+
+    wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
+
+    let ep_store = cx
+        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+        .unwrap();
+
+    let mut events = ep_store
+        .update(&mut cx, |store, cx| {
+            store.register_buffer(&state.buffer, &project, cx);
+            store.set_use_context(true);
+            store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
+            store.debug_info(&project, cx)
+        })
+        .unwrap();
+
+    while let Some(event) = events.next().await {
+        match event {
+            DebugEvent::ContextRetrievalFinished(_) => {
+                break;
+            }
+            _ => {}
         }
+    }
 
-        Ok(buffer)
-    })
+    let context_files = ep_store
+        .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
+        .unwrap();
+
+    example.context = Some(ExampleContext {
+        files: context_files,
+    });
 }
 
-pub async fn open_buffer_with_language_server(
-    project: Entity<Project>,
-    worktree: Entity<Worktree>,
-    path: Arc<RelPath>,
-    ready_languages: &mut HashSet<LanguageId>,
+async fn wait_for_language_server_to_start(
+    example: &Example,
+    project: &Entity<Project>,
+    buffer: &Entity<Buffer>,
     cx: &mut AsyncApp,
-) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
-    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
-
-    let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
-        (
-            project.register_buffer_with_language_servers(&buffer, cx),
-            project.path_style(cx),
-        )
-    })?;
-
-    let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
+) {
+    let language_registry = project
+        .read_with(cx, |project, _| project.languages().clone())
+        .unwrap();
     let result = language_registry
-        .load_language_for_file_path(path.as_std_path())
+        .load_language_for_file_path(&example.cursor_path)
         .await;
 
     if let Err(error) = result
         && !error.is::<LanguageNotFound>()
     {
-        anyhow::bail!(error);
+        panic!("Failed to load language for file path: {}", error);
     }
 
-    let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
-        buffer.language().map(|language| language.id())
-    })?
+    let Some(language_id) = buffer
+        .read_with(cx, |buffer, _cx| {
+            buffer.language().map(|language| language.id())
+        })
+        .unwrap()
     else {
-        return Err(anyhow!("No language for {}", path.display(path_style)));
+        panic!("No language for {:?}", example.cursor_path);
     };
 
-    let log_prefix = format!("{} | ", path.display(path_style));
+    let mut ready_languages = HashSet::default();
+    let log_prefix = format!("{} | ", example.name);
     if !ready_languages.contains(&language_id) {
-        wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
+        wait_for_lang_server(&project, &buffer, log_prefix, cx)
+            .await
+            .unwrap();
         ready_languages.insert(language_id);
     }
 
-    let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
+    let lsp_store = project
+        .read_with(cx, |project, _cx| project.lsp_store())
+        .unwrap();
 
     // hacky wait for buffer to be registered with the language server
     for _ in 0..100 {
-        let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
-            buffer.update(cx, |buffer, cx| {
-                lsp_store
-                    .language_servers_for_local_buffer(&buffer, cx)
-                    .next()
-                    .map(|(_, language_server)| language_server.server_id())
+        if lsp_store
+            .update(cx, |lsp_store, cx| {
+                buffer.update(cx, |buffer, cx| {
+                    lsp_store
+                        .language_servers_for_local_buffer(&buffer, cx)
+                        .next()
+                        .map(|(_, language_server)| language_server.server_id())
+                })
             })
-        })?
-        else {
+            .unwrap()
+            .is_some()
+        {
+            return;
+        } else {
             cx.background_executor()
                 .timer(Duration::from_millis(10))
                 .await;
-            continue;
-        };
-
-        return Ok((lsp_open_handle, language_server_id, buffer));
+        }
     }
 
-    return Err(anyhow!("No language server found for buffer"));
+    panic!("No language server found for buffer");
 }
 
-// TODO: Dedupe with similar function in crates/eval/src/instance.rs
 pub fn wait_for_lang_server(
     project: &Entity<Project>,
     buffer: &Entity<Buffer>,

crates/edit_prediction_cli/src/score.rs 🔗

@@ -0,0 +1,119 @@
+use crate::{
+    PredictArgs,
+    example::{Example, ExampleScore},
+    headless::EpAppState,
+    metrics::{self, ClassificationMetrics},
+    predict::run_prediction,
+};
+use edit_prediction::udiff::DiffLine;
+use gpui::AsyncApp;
+use std::sync::Arc;
+
+pub async fn run_scoring(
+    example: &mut Example,
+    args: &PredictArgs,
+    app_state: Arc<EpAppState>,
+    cx: AsyncApp,
+) {
+    run_prediction(
+        example,
+        Some(args.provider),
+        args.repetitions,
+        app_state,
+        cx,
+    )
+    .await;
+
+    let expected_patch = parse_patch(&example.expected_patch);
+
+    let mut scores = vec![];
+
+    for pred in &example.predictions {
+        let actual_patch = parse_patch(&pred.actual_patch);
+        let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
+        let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
+
+        scores.push(ExampleScore {
+            delta_chr_f,
+            line_match,
+        });
+    }
+
+    example.score = scores;
+}
+
+fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
+    patch.lines().map(DiffLine::parse).collect()
+}
+
+pub fn print_report(examples: &[Example]) {
+    eprintln!(
+        "──────────────────────────────────────────────────────────────────────────────────────"
+    );
+    eprintln!(
+        "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
+        "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
+    );
+    eprintln!(
+        "──────────────────────────────────────────────────────────────────────────────────────"
+    );
+
+    let mut all_line_match_scores = Vec::new();
+    let mut all_delta_chr_f_scores = Vec::new();
+
+    for example in examples {
+        for score in example.score.iter() {
+            let line_match = &score.line_match;
+
+            eprintln!(
+                "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
+                truncate_name(&example.name, 30),
+                line_match.true_positives,
+                line_match.false_positives,
+                line_match.false_negatives,
+                line_match.precision() * 100.0,
+                line_match.recall() * 100.0,
+                line_match.f1_score() * 100.0,
+                score.delta_chr_f
+            );
+
+            all_line_match_scores.push(line_match.clone());
+            all_delta_chr_f_scores.push(score.delta_chr_f);
+        }
+    }
+
+    eprintln!(
+        "──────────────────────────────────────────────────────────────────────────────────────"
+    );
+
+    if !all_line_match_scores.is_empty() {
+        let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
+        let avg_delta_chr_f: f32 =
+            all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
+
+        eprintln!(
+            "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
+            "TOTAL",
+            total_line_match.true_positives,
+            total_line_match.false_positives,
+            total_line_match.false_negatives,
+            total_line_match.precision() * 100.0,
+            total_line_match.recall() * 100.0,
+            total_line_match.f1_score() * 100.0,
+            avg_delta_chr_f
+        );
+        eprintln!(
+            "──────────────────────────────────────────────────────────────────────────────────────"
+        );
+    }
+
+    eprintln!("\n");
+}
+
+fn truncate_name(name: &str, max_len: usize) -> String {
+    if name.len() <= max_len {
+        name.to_string()
+    } else {
+        format!("{}...", &name[..max_len - 3])
+    }
+}

crates/edit_prediction_cli/src/source_location.rs 🔗

@@ -1,70 +0,0 @@
-use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
-
-use ::util::{paths::PathStyle, rel_path::RelPath};
-use anyhow::{Result, anyhow};
-use language::Point;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
-
-#[derive(Debug, Clone, Hash, Eq, PartialEq)]
-pub struct SourceLocation {
-    pub path: Arc<RelPath>,
-    pub point: Point,
-}
-
-impl Serialize for SourceLocation {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        serializer.serialize_str(&self.to_string())
-    }
-}
-
-impl<'de> Deserialize<'de> for SourceLocation {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let s = String::deserialize(deserializer)?;
-        s.parse().map_err(serde::de::Error::custom)
-    }
-}
-
-impl Display for SourceLocation {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(
-            f,
-            "{}:{}:{}",
-            self.path.display(PathStyle::Posix),
-            self.point.row + 1,
-            self.point.column + 1
-        )
-    }
-}
-
-impl FromStr for SourceLocation {
-    type Err = anyhow::Error;
-
-    fn from_str(s: &str) -> Result<Self> {
-        let parts: Vec<&str> = s.split(':').collect();
-        if parts.len() != 3 {
-            return Err(anyhow!(
-                "Invalid source location. Expected 'file.rs:line:column', got '{}'",
-                s
-            ));
-        }
-
-        let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
-        let line: u32 = parts[1]
-            .parse()
-            .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
-        let column: u32 = parts[2]
-            .parse()
-            .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
-
-        // Convert from 1-based to 0-based indexing
-        let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
-
-        Ok(SourceLocation { path, point })
-    }
-}

crates/edit_prediction_cli/src/training/context.rs 🔗

@@ -1,89 +0,0 @@
-use std::path::Path;
-
-use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
-
-#[derive(Debug, Clone, Default, clap::ValueEnum)]
-pub enum ContextType {
-    #[default]
-    CurrentFile,
-}
-
-const MAX_CONTEXT_SIZE: usize = 32768;
-
-pub fn collect_context(
-    context_type: &ContextType,
-    worktree_dir: &Path,
-    cursor: SourceLocation,
-) -> String {
-    let context = match context_type {
-        ContextType::CurrentFile => {
-            let file_path = worktree_dir.join(cursor.path.as_std_path());
-            let context = std::fs::read_to_string(&file_path).unwrap_or_default();
-
-            let context = add_special_tags(&context, worktree_dir, cursor);
-            context
-        }
-    };
-
-    let region_end_offset = context.find(TeacherModel::REGION_END);
-
-    if context.len() <= MAX_CONTEXT_SIZE {
-        return context;
-    }
-
-    if let Some(region_end_offset) = region_end_offset
-        && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
-    {
-        let to_truncate = context.len() - MAX_CONTEXT_SIZE;
-        format!(
-            "[...{} bytes truncated]\n{}\n",
-            to_truncate,
-            &context[to_truncate..]
-        )
-    } else {
-        format!(
-            "{}\n[...{} bytes truncated]\n",
-            &context[..MAX_CONTEXT_SIZE],
-            context.len() - MAX_CONTEXT_SIZE
-        )
-    }
-}
-
-/// Add <|editable_region_start/end|> tags
-fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
-    let path = worktree_dir.join(cursor.path.as_std_path());
-    let file = std::fs::read_to_string(&path).unwrap_or_default();
-    let lines = file.lines().collect::<Vec<_>>();
-    let cursor_row = cursor.point.row as usize;
-    let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
-    let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
-
-    let snippet = lines[start_line..end_line].join("\n");
-
-    if context.contains(&snippet) {
-        let mut cursor_line = lines[cursor_row].to_string();
-        cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
-
-        let mut snippet_with_tags_lines = vec![];
-        snippet_with_tags_lines.push(TeacherModel::REGION_START);
-        snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
-        snippet_with_tags_lines.push(&cursor_line);
-        snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
-        snippet_with_tags_lines.push(TeacherModel::REGION_END);
-        let snippet_with_tags = snippet_with_tags_lines.join("\n");
-
-        context.replace(&snippet, &snippet_with_tags)
-    } else {
-        log::warn!(
-            "Can't find area around the cursor in the context; proceeding without special tags"
-        );
-        context.to_string()
-    }
-}
-
-pub fn strip_special_tags(context: &str) -> String {
-    context
-        .replace(TeacherModel::REGION_START, "")
-        .replace(TeacherModel::REGION_END, "")
-        .replace(TeacherModel::USER_CURSOR, "")
-}

crates/edit_prediction_cli/src/training/distill.rs 🔗

@@ -1,94 +0,0 @@
-use serde::Deserialize;
-use std::sync::Arc;
-
-use crate::{
-    DistillArguments,
-    example::Example,
-    source_location::SourceLocation,
-    training::{
-        context::ContextType,
-        llm_client::LlmClient,
-        teacher::{TeacherModel, TeacherOutput},
-    },
-};
-use anyhow::Result;
-use reqwest_client::ReqwestClient;
-
-#[derive(Debug, Deserialize)]
-pub struct SplitCommit {
-    repo_url: String,
-    commit_sha: String,
-    edit_history: String,
-    expected_patch: String,
-    cursor_position: String,
-}
-
-pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
-    let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
-        .expect("Failed to read split commit dataset")
-        .lines()
-        .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
-        .collect();
-
-    let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
-
-    let llm_client = if let Some(cache_path) = arguments.batch {
-        LlmClient::batch(&cache_path, http_client)?
-    } else {
-        LlmClient::plain(http_client)?
-    };
-
-    let mut teacher = TeacherModel::new(
-        "claude-sonnet-4-5".to_string(),
-        ContextType::CurrentFile,
-        llm_client,
-    );
-
-    let mut num_marked_for_batching = 0;
-
-    for commit in split_commits {
-        if let Some(distilled) = distill_one(&mut teacher, commit).await? {
-            println!("{}", serde_json::to_string(&distilled)?);
-        } else {
-            if num_marked_for_batching == 0 {
-                log::warn!("Marked for batching");
-            }
-            num_marked_for_batching += 1;
-        }
-    }
-
-    eprintln!(
-        "{} requests are marked for batching",
-        num_marked_for_batching
-    );
-    let llm_client = teacher.client;
-    llm_client.sync_batches().await?;
-
-    Ok(())
-}
-
-pub async fn distill_one(
-    teacher: &mut TeacherModel,
-    commit: SplitCommit,
-) -> Result<Option<TeacherOutput>> {
-    let cursor: SourceLocation = commit
-        .cursor_position
-        .parse()
-        .expect("Failed to parse cursor position");
-
-    let path = cursor.path.to_rel_path_buf();
-
-    let example = Example {
-        repository_url: commit.repo_url,
-        revision: commit.commit_sha,
-        uncommitted_diff: commit.edit_history.clone(),
-        cursor_path: path.as_std_path().to_path_buf(),
-        cursor_position: commit.cursor_position,
-        edit_history: commit.edit_history, // todo: trim
-        expected_patch: commit.expected_patch,
-    };
-
-    let prediction = teacher.predict(example).await;
-
-    prediction
-}

crates/edit_prediction_cli/src/training/teacher.rs 🔗

@@ -1,266 +0,0 @@
-use crate::{
-    example::Example,
-    source_location::SourceLocation,
-    training::{
-        context::{ContextType, collect_context, strip_special_tags},
-        llm_client::LlmClient,
-    },
-};
-use anthropic::{Message, RequestContent, ResponseContent, Role};
-use anyhow::Result;
-
-pub struct TeacherModel {
-    pub llm_name: String,
-    pub context: ContextType,
-    pub client: LlmClient,
-}
-
-#[derive(Debug, serde::Serialize)]
-pub struct TeacherOutput {
-    parsed_output: String,
-    prompt: String,
-    raw_llm_response: String,
-    context: String,
-    diff: String,
-}
-
-impl TeacherModel {
-    const PROMPT: &str = include_str!("teacher.prompt.md");
-    pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
-    pub(crate) const REGION_END: &str = "<|editable_region_end|>";
-    pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
-
-    /// Number of lines to include before the cursor position
-    pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
-
-    /// Number of lines to include after the cursor position
-    pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
-
-    /// Truncate edit history to this number of last lines
-    const MAX_HISTORY_LINES: usize = 128;
-
-    pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
-        TeacherModel {
-            llm_name,
-            context,
-            client,
-        }
-    }
-
-    pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
-        let name = input.unique_name();
-        let worktree_dir = input.setup_worktree(name).await?;
-        let cursor: SourceLocation = input
-            .cursor_position
-            .parse()
-            .expect("Failed to parse cursor position");
-
-        let context = collect_context(&self.context, &worktree_dir, cursor.clone());
-        let edit_history = Self::format_edit_history(&input.edit_history);
-
-        let prompt = Self::PROMPT
-            .replace("{{context}}", &context)
-            .replace("{{edit_history}}", &edit_history);
-
-        let messages = vec![Message {
-            role: Role::User,
-            content: vec![RequestContent::Text {
-                text: prompt.clone(),
-                cache_control: None,
-            }],
-        }];
-
-        let Some(response) = self
-            .client
-            .generate(self.llm_name.clone(), 16384, messages)
-            .await?
-        else {
-            return Ok(None);
-        };
-
-        let response_text = response
-            .content
-            .into_iter()
-            .filter_map(|content| match content {
-                ResponseContent::Text { text } => Some(text),
-                _ => None,
-            })
-            .collect::<Vec<String>>()
-            .join("\n");
-
-        let parsed_output = self.parse_response(&response_text);
-
-        let original_editable_region = Self::extract_editable_region(&context);
-        let context_after_edit = context.replace(&original_editable_region, &parsed_output);
-        let context_after_edit = strip_special_tags(&context_after_edit);
-        let context_before_edit = strip_special_tags(&context);
-        let diff = language::unified_diff(&context_before_edit, &context_after_edit);
-
-        // zeta distill --batch batch_results.txt
-        // zeta distill
-        // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
-        //  - store LLM requests in a batch, don't actual send the request
-        //  - send the batch (2000 requests) after all inputs are processed
-        // 2. `zeta send-batches`
-        //   - upload the batch to Anthropic
-
-        // https://platform.claude.com/docs/en/build-with-claude/batch-processing
-        // https://crates.io/crates/anthropic-sdk-rust
-
-        //   - poll for results
-        //   - when ready, store results in cache (a database)
-        // 3. `zeta distill` again
-        //    - use the cached results this time
-
-        Ok(Some(TeacherOutput {
-            parsed_output,
-            prompt,
-            raw_llm_response: response_text,
-            context,
-            diff,
-        }))
-    }
-
-    fn parse_response(&self, content: &str) -> String {
-        let codeblock = Self::extract_last_codeblock(content);
-        let editable_region = Self::extract_editable_region(&codeblock);
-
-        editable_region
-    }
-
-    /// Extract content from the last code-fenced block if any, or else return content as is
-    fn extract_last_codeblock(text: &str) -> String {
-        let mut last_block = None;
-        let mut search_start = 0;
-
-        while let Some(start) = text[search_start..].find("```") {
-            let start = start + search_start;
-            let bytes = text.as_bytes();
-            let mut backtick_end = start;
-
-            while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
-                backtick_end += 1;
-            }
-
-            let backtick_count = backtick_end - start;
-            let closing_backticks = "`".repeat(backtick_count);
-
-            if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
-                let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
-                last_block = Some(code_block.to_string());
-                search_start = backtick_end + end_pos + backtick_count;
-            } else {
-                break;
-            }
-        }
-
-        last_block.unwrap_or_else(|| text.to_string())
-    }
-
-    fn extract_editable_region(text: &str) -> String {
-        let start = text
-            .find(Self::REGION_START)
-            .map_or(0, |pos| pos + Self::REGION_START.len());
-        let end = text.find(Self::REGION_END).unwrap_or(text.len());
-
-        text[start..end].to_string()
-    }
-
-    /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
-    fn format_edit_history(edit_history: &str) -> String {
-        let lines = edit_history
-            .lines()
-            .filter(|&s| Self::is_content_line(s))
-            .collect::<Vec<_>>();
-
-        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
-            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
-        } else {
-            &lines
-        };
-        history_lines.join("\n")
-    }
-
-    fn is_content_line(s: &str) -> bool {
-        s.starts_with("-")
-            || s.starts_with("+")
-            || s.starts_with(" ")
-            || s.starts_with("---")
-            || s.starts_with("+++")
-            || s.starts_with("@@")
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_parse_response() {
-        let teacher = TeacherModel::new(
-            "test".to_string(),
-            ContextType::CurrentFile,
-            LlmClient::dummy(),
-        );
-        let response = "This is a test response.";
-        let parsed = teacher.parse_response(response);
-        assert_eq!(parsed, response.to_string());
-
-        let response = indoc::indoc! {"
-            Some thinking
-
-            `````
-            actual response
-            `````
-            "};
-        let parsed = teacher.parse_response(response);
-        assert_eq!(parsed, "actual response");
-    }
-
-    #[test]
-    fn test_extract_last_code_block() {
-        let text = indoc::indoc! {"
-            Some thinking
-
-            ```
-            first block
-            ```
-
-            `````
-            last block
-            `````
-            "};
-        let last_block = TeacherModel::extract_last_codeblock(text);
-        assert_eq!(last_block, "last block");
-    }
-
-    #[test]
-    fn test_extract_editable_region() {
-        let teacher = TeacherModel::new(
-            "test".to_string(),
-            ContextType::CurrentFile,
-            LlmClient::dummy(),
-        );
-        let response = indoc::indoc! {"
-            some lines
-            are
-            here
-            <|editable_region_start|>
-            one
-            two three
-
-            <|editable_region_end|>
-            more
-            lines here
-            "};
-        let parsed = teacher.parse_response(response);
-        assert_eq!(
-            parsed,
-            indoc::indoc! {"
-            one
-            two three
-
-            "}
-        );
-    }
-}

crates/edit_prediction_context/Cargo.toml 🔗

@@ -26,6 +26,7 @@ serde.workspace = true
 smallvec.workspace = true
 tree-sitter.workspace = true
 util.workspace = true
+zeta_prompt.workspace = true
 
 [dev-dependencies]
 env_logger.workspace = true

crates/edit_prediction_context/src/assemble_excerpts.rs 🔗

@@ -1,6 +1,6 @@
-use crate::RelatedExcerpt;
 use language::{BufferSnapshot, OffsetRangeExt as _, Point};
 use std::ops::Range;
+use zeta_prompt::RelatedExcerpt;
 
 #[cfg(not(test))]
 const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
 
     input_ranges
         .into_iter()
-        .map(|range| {
-            let offset_range = range.to_offset(buffer);
-            RelatedExcerpt {
-                point_range: range,
-                anchor_range: buffer.anchor_before(offset_range.start)
-                    ..buffer.anchor_after(offset_range.end),
-                text: buffer.as_rope().slice(offset_range),
-            }
+        .map(|range| RelatedExcerpt {
+            row_range: range.start.row..range.end.row,
+            text: buffer.text_for_range(range).collect(),
         })
         .collect()
 }

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -3,13 +3,13 @@ use anyhow::Result;
 use collections::HashMap;
 use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
 use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
 use project::{LocationLink, Project, ProjectPath};
-use serde::{Serialize, Serializer};
 use smallvec::SmallVec;
 use std::{
     collections::hash_map,
     ops::Range,
+    path::Path,
     sync::Arc,
     time::{Duration, Instant},
 };
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
 
 pub use cloud_llm_client::predict_edits_v3::Line;
 pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
+pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
 
 const IDENTIFIER_LINE_COUNT: u32 = 3;
 
 pub struct RelatedExcerptStore {
     project: WeakEntity<Project>,
-    related_files: Vec<RelatedFile>,
+    related_files: Arc<[RelatedFile]>,
+    related_file_buffers: Vec<Entity<Buffer>>,
     cache: HashMap<Identifier, Arc<CacheEntry>>,
     update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
     identifier_line_count: u32,
@@ -68,82 +70,6 @@ struct CachedDefinition {
     anchor_range: Range<Anchor>,
 }
 
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedFile {
-    #[serde(serialize_with = "serialize_project_path")]
-    pub path: ProjectPath,
-    #[serde(skip)]
-    pub buffer: WeakEntity<Buffer>,
-    pub excerpts: Vec<RelatedExcerpt>,
-    pub max_row: u32,
-}
-
-impl RelatedFile {
-    pub fn merge_excerpts(&mut self) {
-        self.excerpts.sort_unstable_by(|a, b| {
-            a.point_range
-                .start
-                .cmp(&b.point_range.start)
-                .then(b.point_range.end.cmp(&a.point_range.end))
-        });
-
-        let mut index = 1;
-        while index < self.excerpts.len() {
-            if self.excerpts[index - 1]
-                .point_range
-                .end
-                .cmp(&self.excerpts[index].point_range.start)
-                .is_ge()
-            {
-                let removed = self.excerpts.remove(index);
-                if removed
-                    .point_range
-                    .end
-                    .cmp(&self.excerpts[index - 1].point_range.end)
-                    .is_gt()
-                {
-                    self.excerpts[index - 1].point_range.end = removed.point_range.end;
-                    self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
-                }
-            } else {
-                index += 1;
-            }
-        }
-    }
-}
-
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedExcerpt {
-    #[serde(skip)]
-    pub anchor_range: Range<Anchor>,
-    #[serde(serialize_with = "serialize_point_range")]
-    pub point_range: Range<Point>,
-    #[serde(serialize_with = "serialize_rope")]
-    pub text: Rope,
-}
-
-fn serialize_project_path<S: Serializer>(
-    project_path: &ProjectPath,
-    serializer: S,
-) -> Result<S::Ok, S::Error> {
-    project_path.path.serialize(serializer)
-}
-
-fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
-    rope.to_string().serialize(serializer)
-}
-
-fn serialize_point_range<S: Serializer>(
-    range: &Range<Point>,
-    serializer: S,
-) -> Result<S::Ok, S::Error> {
-    [
-        [range.start.row, range.start.column],
-        [range.end.row, range.end.column],
-    ]
-    .serialize(serializer)
-}
-
 const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
 
 impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
         RelatedExcerptStore {
             project: project.downgrade(),
             update_tx,
-            related_files: Vec::new(),
+            related_files: Vec::new().into(),
+            related_file_buffers: Vec::new(),
             cache: Default::default(),
             identifier_line_count: IDENTIFIER_LINE_COUNT,
         }
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
         self.update_tx.unbounded_send((buffer, position)).ok();
     }
 
-    pub fn related_files(&self) -> &[RelatedFile] {
-        &self.related_files
+    pub fn related_files(&self) -> Arc<[RelatedFile]> {
+        self.related_files.clone()
+    }
+
+    pub fn related_files_with_buffers(
+        &self,
+    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
+        self.related_files
+            .iter()
+            .cloned()
+            .zip(self.related_file_buffers.iter().cloned())
+    }
+
+    pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
+        self.related_files = files.into();
     }
 
     async fn fetch_excerpts(
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
         }
         mean_definition_latency /= cache_miss_count.max(1) as u32;
 
-        let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
+        let (new_cache, related_files, related_file_buffers) =
+            rebuild_related_files(&project, new_cache, cx).await?;
 
         if let Some(file) = &file {
             log::debug!(
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
 
         this.update(cx, |this, cx| {
             this.cache = new_cache;
-            this.related_files = related_files;
+            this.related_files = related_files.into();
+            this.related_file_buffers = related_file_buffers;
             cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                 cache_hit_count,
                 cache_miss_count,
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
 }
 
 async fn rebuild_related_files(
+    project: &Entity<Project>,
     new_entries: HashMap<Identifier, Arc<CacheEntry>>,
     cx: &mut AsyncApp,
-) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
+) -> Result<(
+    HashMap<Identifier, Arc<CacheEntry>>,
+    Vec<RelatedFile>,
+    Vec<Entity<Buffer>>,
+)> {
     let mut snapshots = HashMap::default();
+    let mut worktree_root_names = HashMap::default();
     for entry in new_entries.values() {
         for definition in &entry.definitions {
             if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
                         .read_with(cx, |buffer, _| buffer.snapshot())?,
                 );
             }
+            let worktree_id = definition.path.worktree_id;
+            if let hash_map::Entry::Vacant(e) =
+                worktree_root_names.entry(definition.path.worktree_id)
+            {
+                project.read_with(cx, |project, cx| {
+                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
+                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
+                    }
+                })?;
+            }
         }
     }
 
     Ok(cx
         .background_spawn(async move {
-            let mut files = Vec::<RelatedFile>::new();
+            let mut files = Vec::new();
             let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
             let mut paths_by_buffer = HashMap::default();
             for entry in new_entries.values() {
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
                     continue;
                 };
                 let excerpts = assemble_excerpts(snapshot, ranges);
-                files.push(RelatedFile {
-                    path: project_path.clone(),
-                    buffer: buffer.downgrade(),
-                    excerpts,
-                    max_row: snapshot.max_point().row,
-                });
+                let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
+                    continue;
+                };
+
+                let path = Path::new(&format!(
+                    "{}/{}",
+                    root_name,
+                    project_path.path.as_unix_str()
+                ))
+                .into();
+
+                files.push((
+                    buffer,
+                    RelatedFile {
+                        path,
+                        excerpts,
+                        max_row: snapshot.max_point().row,
+                    },
+                ));
             }
 
-            files.sort_by_key(|file| file.path.clone());
-            (new_entries, files)
+            files.sort_by_key(|(_, file)| file.path.clone());
+            let (related_buffers, related_files) = files.into_iter().unzip();
+
+            (new_entries, related_files, related_buffers)
         })
         .await)
 }

crates/edit_prediction_context/src/edit_prediction_context_tests.rs 🔗

@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
             &excerpts,
             &[
                 (
-                    "src/company.rs",
+                    "root/src/company.rs",
                     &[indoc! {"
                         pub struct Company {
                             owner: Arc<Person>,
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
                         }"}],
                 ),
                 (
-                    "src/main.rs",
+                    "root/src/main.rs",
                     &[
                         indoc! {"
                         pub struct Session {
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
                     ],
                 ),
                 (
-                    "src/person.rs",
+                    "root/src/person.rs",
                     &[
                         indoc! {"
                         impl Person {
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
                 .iter()
                 .map(|excerpt| excerpt.text.to_string())
                 .collect::<Vec<_>>();
-            (file.path.path.as_unix_str(), excerpts)
+            (file.path.to_str().unwrap(), excerpts)
         })
         .collect::<Vec<_>>();
     let expected_excerpts = expected_files
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
         if excerpt.text.is_empty() {
             continue;
         }
-        if current_row < excerpt.point_range.start.row {
+        if current_row < excerpt.row_range.start {
             writeln!(&mut output, "…").unwrap();
         }
-        current_row = excerpt.point_range.start.row;
+        current_row = excerpt.row_range.start;
 
         for line in excerpt.text.to_string().lines() {
             output.push_str(line);

crates/edit_prediction_ui/Cargo.toml 🔗

@@ -17,7 +17,6 @@ anyhow.workspace = true
 buffer_diff.workspace = true
 client.workspace = true
 cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
 codestral.workspace = true
 command_palette_hooks.workspace = true
 copilot.workspace = true
@@ -46,6 +45,7 @@ ui_input.workspace = true
 util.workspace = true
 workspace.workspace = true
 zed_actions.workspace = true
+zeta_prompt.workspace = true
 
 [dev-dependencies]
 copilot = { workspace = true, features = ["test-support"] }

crates/edit_prediction_ui/src/edit_prediction_context_view.rs 🔗

@@ -17,7 +17,7 @@ use gpui::{
 };
 use multi_buffer::MultiBuffer;
 use project::Project;
-use text::OffsetRangeExt;
+use text::Point;
 use ui::{
     ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
     StyledTypography as _, h_flex, v_flex,
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
     ) -> Self {
         let store = EditPredictionStore::global(client, user_store, cx);
 
-        let mut debug_rx = store.update(cx, |store, _| store.debug_info());
+        let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
         let _update_task = cx.spawn_in(window, async move |this, cx| {
             while let Some(event) = debug_rx.next().await {
                 this.update_in(cx, |this, window, cx| {
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
                     self.handle_context_retrieval_finished(info, window, cx);
                 }
             }
-            DebugEvent::EditPredictionRequested(_) => {}
+            DebugEvent::EditPredictionStarted(_) => {}
+            DebugEvent::EditPredictionFinished(_) => {}
         }
     }
 
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
         run.finished_at = Some(info.timestamp);
         run.metadata = info.metadata;
 
-        let project = self.project.clone();
         let related_files = self
             .store
             .read(cx)
-            .context_for_project(&self.project, cx)
-            .to_vec();
+            .context_for_project_with_buffers(&self.project, cx)
+            .map_or(Vec::new(), |files| files.collect());
 
         let editor = run.editor.clone();
         let multibuffer = run.editor.read(cx).buffer().clone();
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
 
         cx.spawn_in(window, async move |this, cx| {
             let mut paths = Vec::new();
-            for related_file in related_files {
-                let (buffer, point_ranges): (_, Vec<_>) =
-                    if let Some(buffer) = related_file.buffer.upgrade() {
-                        let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
-
-                        (
-                            buffer,
-                            related_file
-                                .excerpts
-                                .iter()
-                                .map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
-                                .collect(),
-                        )
-                    } else {
-                        (
-                            project
-                                .update(cx, |project, cx| {
-                                    project.open_buffer(related_file.path.clone(), cx)
-                                })?
-                                .await?,
-                            related_file
-                                .excerpts
-                                .iter()
-                                .map(|excerpt| excerpt.point_range.clone())
-                                .collect(),
-                        )
-                    };
+            for (related_file, buffer) in related_files {
+                let point_ranges = related_file
+                    .excerpts
+                    .iter()
+                    .map(|excerpt| {
+                        Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
+                    })
+                    .collect::<Vec<_>>();
                 cx.update(|_, cx| {
                     let path = PathKey::for_buffer(&buffer, cx);
                     paths.push((path, buffer, point_ranges));

crates/edit_prediction_ui/src/rate_prediction_modal.rs 🔗

@@ -1,5 +1,4 @@
 use buffer_diff::{BufferDiff, BufferDiffSnapshot};
-use cloud_zeta2_prompt::write_codeblock;
 use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
 use editor::{Editor, ExcerptRange, MultiBuffer};
 use feature_flags::FeatureFlag;
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
             write!(&mut formatted_inputs, "## Events\n\n").unwrap();
 
             for event in &prediction.inputs.events {
-                write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
+                formatted_inputs.push_str("```diff\n");
+                zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
+                formatted_inputs.push_str("```\n\n");
             }
 
-            write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
-
-            for included_file in &prediction.inputs.included_files {
-                let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
+            write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
 
+            for included_file in prediction.inputs.related_files.as_ref() {
                 write!(
                     &mut formatted_inputs,
                     "### {}\n\n",
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
                 )
                 .unwrap();
 
-                write_codeblock(
-                    &included_file.path,
-                    &included_file.excerpts,
-                    if included_file.path == prediction.inputs.cursor_path {
-                        cursor_insertions.as_slice()
-                    } else {
-                        &[]
-                    },
-                    included_file.max_row,
-                    false,
-                    &mut formatted_inputs,
-                );
+                for excerpt in included_file.excerpts.iter() {
+                    write!(
+                        &mut formatted_inputs,
+                        "```{}\n{}\n```\n",
+                        included_file.path.display(),
+                        excerpt.text
+                    )
+                    .unwrap();
+                }
             }
 
+            write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
+
+            writeln!(
+                &mut formatted_inputs,
+                "```{}\n{}<CURSOR>{}\n```\n",
+                prediction.inputs.cursor_path.display(),
+                &prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
+                &prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
+            )
+            .unwrap();
+
             self.active_prediction = Some(ActivePrediction {
                 prediction,
                 feedback_editor: cx.new(|cx| {

crates/zeta_prompt/Cargo.toml 🔗

@@ -0,0 +1,15 @@
+[package]
+name = "zeta_prompt"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/zeta_prompt.rs"
+
+[dependencies]
+serde.workspace = true

crates/zeta_prompt/src/zeta_prompt.rs 🔗

@@ -0,0 +1,165 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::Write;
+use std::ops::Range;
+use std::path::Path;
+use std::sync::Arc;
+
+pub const CURSOR_MARKER: &str = "<|user_cursor|>";
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ZetaPromptInput {
+    pub cursor_path: Arc<Path>,
+    pub cursor_excerpt: Arc<str>,
+    pub editable_range_in_excerpt: Range<usize>,
+    pub cursor_offset_in_excerpt: usize,
+    pub events: Vec<Arc<Event>>,
+    pub related_files: Arc<[RelatedFile]>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "event")]
+pub enum Event {
+    BufferChange {
+        path: Arc<Path>,
+        old_path: Arc<Path>,
+        diff: String,
+        predicted: bool,
+        in_open_source_repo: bool,
+    },
+}
+
+pub fn write_event(prompt: &mut String, event: &Event) {
+    fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
+        for component in path.components() {
+            prompt.push('/');
+            write!(prompt, "{}", component.as_os_str().display()).ok();
+        }
+    }
+    match event {
+        Event::BufferChange {
+            path,
+            old_path,
+            diff,
+            predicted,
+            in_open_source_repo: _,
+        } => {
+            if *predicted {
+                prompt.push_str("// User accepted prediction:\n");
+            }
+            prompt.push_str("--- a");
+            write_path_as_unix_str(prompt, old_path.as_ref());
+            prompt.push_str("\n+++ b");
+            write_path_as_unix_str(prompt, path.as_ref());
+            prompt.push('\n');
+            prompt.push_str(diff);
+        }
+    }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct RelatedFile {
+    pub path: Arc<Path>,
+    pub max_row: u32,
+    pub excerpts: Vec<RelatedExcerpt>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct RelatedExcerpt {
+    pub row_range: Range<u32>,
+    pub text: String,
+}
+
+pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
+    let mut prompt = String::new();
+    write_related_files(&mut prompt, &input.related_files);
+    write_edit_history_section(&mut prompt, input);
+    write_cursor_excerpt_section(&mut prompt, input);
+    prompt
+}
+
+pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
+    push_delimited(prompt, "related_files", &[], |prompt| {
+        for file in related_files {
+            let path_str = file.path.to_string_lossy();
+            push_delimited(prompt, "related_file", &[("path", &path_str)], |prompt| {
+                for excerpt in &file.excerpts {
+                    push_delimited(
+                        prompt,
+                        "related_excerpt",
+                        &[(
+                            "lines",
+                            &format!(
+                                "{}-{}",
+                                excerpt.row_range.start + 1,
+                                excerpt.row_range.end + 1
+                            ),
+                        )],
+                        |prompt| {
+                            prompt.push_str(&excerpt.text);
+                            prompt.push('\n');
+                        },
+                    );
+                }
+            });
+        }
+    });
+}
+
+fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
+    push_delimited(prompt, "edit_history", &[], |prompt| {
+        if input.events.is_empty() {
+            prompt.push_str("(No edit history)");
+        } else {
+            for event in &input.events {
+                write_event(prompt, event);
+            }
+        }
+    });
+}
+
+fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+    push_delimited(prompt, "cursor_excerpt", &[], |prompt| {
+        let path_str = input.cursor_path.to_string_lossy();
+        push_delimited(prompt, "file", &[("path", &path_str)], |prompt| {
+            prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+            push_delimited(prompt, "editable_region", &[], |prompt| {
+                prompt.push_str(
+                    &input.cursor_excerpt
+                        [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+                );
+                prompt.push_str(CURSOR_MARKER);
+                prompt.push_str(
+                    &input.cursor_excerpt
+                        [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+                );
+            });
+            prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+        });
+    });
+}
+
+fn push_delimited(
+    prompt: &mut String,
+    tag: &'static str,
+    arguments: &[(&str, &str)],
+    cb: impl FnOnce(&mut String),
+) {
+    if !prompt.ends_with("\n") {
+        prompt.push('\n');
+    }
+    prompt.push('<');
+    prompt.push_str(tag);
+    for (arg_name, arg_value) in arguments {
+        write!(prompt, " {}=\"{}\"", arg_name, arg_value).ok();
+    }
+    prompt.push_str(">\n");
+
+    cb(prompt);
+
+    if !prompt.ends_with('\n') {
+        prompt.push('\n');
+    }
+    prompt.push_str("</");
+    prompt.push_str(tag);
+    prompt.push_str(">\n");
+}