Add ep distill command, for generating edit prediction training examples (#44670)

Created by Max Brunsfeld, Oleksiy Syvokon, and Agus Zubiaga

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>

Change summary

crates/edit_prediction/Cargo.toml                |   2 
crates/edit_prediction/src/edit_prediction.rs    |  26 +-
crates/edit_prediction/src/udiff.rs              |   2 
crates/edit_prediction/src/zeta2.rs              |  20 +
crates/edit_prediction_cli/Cargo.toml            |   2 
crates/edit_prediction_cli/src/distill.rs        |  14 +
crates/edit_prediction_cli/src/example.rs        |  23 +-
crates/edit_prediction_cli/src/format_prompt.rs  | 147 ++++++++---------
crates/edit_prediction_cli/src/main.rs           |   9 +
crates/edit_prediction_cli/src/predict.rs        |  16 +
crates/edit_prediction_cli/src/teacher.prompt.md |   1 
11 files changed, 149 insertions(+), 113 deletions(-)

Detailed changes

crates/edit_prediction/Cargo.toml

@@ -12,7 +12,7 @@ workspace = true
 path = "src/edit_prediction.rs"
 
 [features]
-eval-support = []
+cli-support = []
 
 [dependencies]
 ai_onboarding.workspace = true

crates/edit_prediction/src/edit_prediction.rs

@@ -55,7 +55,7 @@ pub mod open_ai_response;
 mod prediction;
 pub mod sweep_ai;
 
-#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
+#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
 pub mod udiff;
 
 mod zed_edit_prediction_delegate;
@@ -158,7 +158,7 @@ pub struct EditPredictionStore {
     use_context: bool,
     options: ZetaOptions,
     update_required: bool,
-    #[cfg(feature = "eval-support")]
+    #[cfg(feature = "cli-support")]
     eval_cache: Option<Arc<dyn EvalCache>>,
     edit_prediction_model: EditPredictionModel,
     pub sweep_ai: SweepAi,
@@ -505,7 +505,7 @@ impl EditPredictionStore {
                 },
             ),
             update_required: false,
-            #[cfg(feature = "eval-support")]
+            #[cfg(feature = "cli-support")]
             eval_cache: None,
             edit_prediction_model: EditPredictionModel::Zeta2,
             sweep_ai: SweepAi::new(cx),
@@ -554,7 +554,7 @@ impl EditPredictionStore {
             .is_some()
     }
 
-    #[cfg(feature = "eval-support")]
+    #[cfg(feature = "cli-support")]
     pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
         self.eval_cache = Some(cache);
     }
@@ -1590,8 +1590,8 @@ impl EditPredictionStore {
         client: Arc<Client>,
         llm_token: LlmApiToken,
         app_version: Version,
-        #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
-        #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
+        #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+        #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
     ) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
         let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
             http_client::Url::parse(&predict_edits_url)?
@@ -1601,7 +1601,7 @@ impl EditPredictionStore {
                 .build_zed_llm_url("/predict_edits/raw", &[])?
         };
 
-        #[cfg(feature = "eval-support")]
+        #[cfg(feature = "cli-support")]
         let cache_key = if let Some(cache) = eval_cache {
             use collections::FxHasher;
             use std::hash::{Hash, Hasher};
@@ -1635,7 +1635,7 @@ impl EditPredictionStore {
         )
         .await?;
 
-        #[cfg(feature = "eval-support")]
+        #[cfg(feature = "cli-support")]
         if let Some((cache, request, key)) = cache_key {
             cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
         }
@@ -1767,7 +1767,7 @@ impl EditPredictionStore {
         }
     }
 
-    #[cfg(feature = "eval-support")]
+    #[cfg(feature = "cli-support")]
     pub fn set_context_for_buffer(
         &mut self,
         project: &Entity<Project>,
@@ -1892,10 +1892,10 @@ pub struct ZedUpdateRequiredError {
     minimum_version: Version,
 }
 
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
 pub type EvalCacheKey = (EvalCacheEntryKind, u64);
 
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
 #[derive(Debug, Clone, Copy, PartialEq)]
 pub enum EvalCacheEntryKind {
     Context,
@@ -1903,7 +1903,7 @@ pub enum EvalCacheEntryKind {
     Prediction,
 }
 
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
 impl std::fmt::Display for EvalCacheEntryKind {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
@@ -1914,7 +1914,7 @@ impl std::fmt::Display for EvalCacheEntryKind {
     }
 }
 
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
 pub trait EvalCache: Send + Sync {
     fn read(&self, key: EvalCacheKey) -> Option<String>;
     fn write(&self, key: EvalCacheKey, input: &str, value: &str);
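
The `EvalCache` trait (now gated behind the renamed `cli-support` feature) is the whole caching surface a caller needs to implement. A minimal sketch of a directory-backed implementation (hypothetical, not part of this change):

```rust
use edit_prediction::{EvalCache, EvalCacheKey};
use std::path::PathBuf;

// Hypothetical cache: one file per entry, named by the entry kind's
// Display impl and the u64 hash that make up EvalCacheKey.
struct FsEvalCache {
    dir: PathBuf,
}

impl EvalCache for FsEvalCache {
    fn read(&self, key: EvalCacheKey) -> Option<String> {
        let (kind, hash) = key;
        std::fs::read_to_string(self.dir.join(format!("{kind}-{hash:016x}.json"))).ok()
    }

    fn write(&self, key: EvalCacheKey, input: &str, value: &str) {
        let (kind, hash) = key;
        // Keeping the request input next to the cached value is an
        // assumption about what the `input` parameter is for.
        let _ = std::fs::write(self.dir.join(format!("{kind}-{hash:016x}.input")), input);
        let _ = std::fs::write(self.dir.join(format!("{kind}-{hash:016x}.json")), value);
    }
}
```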

crates/edit_prediction/src/udiff.rs

@@ -138,7 +138,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
             DiffEvent::Hunk { hunk, .. } => {
                 let hunk_offset = text
                     .find(&hunk.context)
-                    .ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
+                    .ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
                 for edit in hunk.edits.iter().rev() {
                     let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
                     text.replace_range(range, &edit.text);
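
Note that `apply_diff_to_string` locates each hunk by searching for its context text in the input (the `find` call above) rather than trusting line numbers. A minimal usage sketch, not from this change; whether the parser requires or ignores the file headers and `@@` line shown here is an assumption:

```rust
use edit_prediction::udiff::apply_diff_to_string;

fn main() -> anyhow::Result<()> {
    let text = "fn add(a: i32, b: i32) -> i32 {\n    a - b\n}\n";
    // A conventional unified diff; the hunk is resolved by matching its
    // context against `text`.
    let diff = concat!(
        "--- a/math.rs\n",
        "+++ b/math.rs\n",
        "@@ -1,3 +1,3 @@\n",
        " fn add(a: i32, b: i32) -> i32 {\n",
        "-    a - b\n",
        "+    a + b\n",
        " }\n",
    );
    let patched = apply_diff_to_string(diff, text)?;
    println!("{patched}"); // the body should now read `a + b`
    Ok(())
}
```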

crates/edit_prediction/src/zeta2.rs

@@ -1,4 +1,4 @@
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
 use crate::EvalCacheEntryKind;
 use crate::open_ai_response::text_from_response;
 use crate::prediction::EditPredictionResult;
@@ -44,7 +44,7 @@ pub fn request_prediction_with_zeta2(
     let llm_token = store.llm_token.clone();
     let app_version = AppVersion::global(cx);
 
-    #[cfg(feature = "eval-support")]
+    #[cfg(feature = "cli-support")]
     let eval_cache = store.eval_cache.clone();
 
     let request_task = cx.background_spawn({
@@ -95,9 +95,9 @@ pub fn request_prediction_with_zeta2(
                 client,
                 llm_token,
                 app_version,
-                #[cfg(feature = "eval-support")]
+                #[cfg(feature = "cli-support")]
                 eval_cache,
-                #[cfg(feature = "eval-support")]
+                #[cfg(feature = "cli-support")]
                 EvalCacheEntryKind::Prediction,
             )
             .await;
@@ -226,3 +226,15 @@ pub fn zeta2_prompt_input(
     };
     (editable_offset_range, prompt_input)
 }
+
+#[cfg(feature = "cli-support")]
+pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> String {
+    eprintln!("{}", patch);
+    eprintln!("---------------------");
+    eprintln!("{}", input.cursor_excerpt);
+    crate::udiff::apply_diff_to_string(
+        patch,
+        &input.cursor_excerpt[input.editable_range_in_excerpt.clone()],
+    )
+    .unwrap()
+}

crates/edit_prediction_cli/Cargo.toml

@@ -52,7 +52,7 @@ sqlez_macros.workspace = true
 terminal_view.workspace = true
 util.workspace = true
 watch.workspace = true
-edit_prediction = { workspace = true, features = ["eval-support"] }
+edit_prediction = { workspace = true, features = ["cli-support"] }
 wasmtime.workspace = true
 zeta_prompt.workspace = true
 zlog.workspace = true

crates/edit_prediction_cli/src/distill.rs

@@ -0,0 +1,14 @@
+use std::mem;
+
+use crate::example::Example;
+
+pub async fn run_distill(example: &mut Example) {
+    let [prediction]: [_; 1] = mem::take(&mut example.predictions)
+        .try_into()
+        .expect("Run predict first with a single repetition");
+
+    example.expected_patch = prediction.actual_patch;
+    example.prompt = None;
+    example.predictions = Vec::new();
+    example.score = Vec::new();
+}
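
The `try_into` on the taken `Vec` is the "run predict exactly once" check: converting to a length-1 array fails for any other count. A self-contained sketch of the same transformation on simplified stand-in types (the real `Example` carries more fields):

```rust
use std::mem;

// Simplified stand-ins (hypothetical) for Example/ExamplePrediction;
// each prediction is reduced to its actual_patch string.
struct Ex {
    expected_patch: String,
    prompt: Option<String>,
    predictions: Vec<String>,
    score: Vec<f32>,
}

fn distill(ex: &mut Ex) {
    // Fails unless there is exactly one prediction, mirroring run_distill.
    let [patch]: [String; 1] = mem::take(&mut ex.predictions)
        .try_into()
        .expect("Run predict first with a single repetition");
    ex.expected_patch = patch; // the teacher's output becomes the training label
    ex.prompt = None;
    ex.score.clear();
}

fn main() {
    let mut ex = Ex {
        expected_patch: String::new(),
        prompt: Some("formatted prompt".into()),
        predictions: vec!["@@ teacher patch @@".into()],
        score: vec![1.0],
    };
    distill(&mut ex);
    assert_eq!(ex.expected_patch, "@@ teacher patch @@");
    assert!(ex.prompt.is_none() && ex.predictions.is_empty() && ex.score.is_empty());
}
```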

crates/edit_prediction_cli/src/example.rs

@@ -25,6 +25,7 @@ pub struct Example {
     pub name: String,
     pub repository_url: String,
     pub revision: String,
+    #[serde(default)]
     pub uncommitted_diff: String,
     pub cursor_path: Arc<Path>,
     pub cursor_position: String,
@@ -195,9 +196,9 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
                     .enumerate()
                     .map(|(line_ix, line)| {
                         let mut example =
-                            serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
+                            serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
                                 panic!(
-                                    "Failed to parse example on {}:{}",
+                                    "Failed to parse example on {}:{}\n{error}",
                                     path.display(),
                                     line_ix + 1
                                 )
@@ -264,12 +265,12 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
         state: None,
     };
 
-    let mut name = String::new();
     let mut text = String::new();
     let mut block_info: CowStr = "".into();
 
     #[derive(PartialEq)]
     enum Section {
+        Start,
         UncommittedDiff,
         EditHistory,
         CursorPosition,
@@ -278,14 +279,16 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
         Other,
     }
 
-    let mut current_section = Section::Other;
+    let mut current_section = Section::Start;
 
     for event in parser {
         match event {
             Event::Text(line) => {
                 text.push_str(&line);
 
-                if let Some((field, value)) = line.split_once('=') {
+                if let Section::Start = current_section
+                    && let Some((field, value)) = line.split_once('=')
+                {
                     match field.trim() {
                         REPOSITORY_URL_FIELD => {
                             example.repository_url = value.trim().to_string();
@@ -297,14 +300,6 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
                     }
                 }
             }
-            Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
-                if !name.is_empty() {
-                    anyhow::bail!(
-                        "Found multiple H1 headings. There should only be one with the name of the example."
-                    );
-                }
-                name = mem::take(&mut text);
-            }
             Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
                 let title = mem::take(&mut text);
                 current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
@@ -363,7 +358,7 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
                     Section::ExpectedPatch => {
                         example.expected_patch = mem::take(&mut text);
                     }
-                    Section::Other => {}
+                    Section::Start | Section::Other => {}
                 }
             }
             _ => {}
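
With the new `Section::Start` state, `field = value` pairs are only honored before the first H2 heading, and the example name is no longer taken from an H1. A hedged sketch of the markdown layout the parser implies; the exact heading titles are assumptions derived from the section names:

```markdown
repository_url = https://github.com/example/repo
revision = 0123abc

## Uncommitted Diff
...

## Edit History
...

## Cursor Position
...

## Expected Patch
...
```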

crates/edit_prediction_cli/src/format_prompt.rs

@@ -2,9 +2,13 @@ use crate::{
     PromptFormat,
     example::{Example, ExamplePrompt},
     headless::EpAppState,
+    load_project::run_load_project,
     retrieve_context::run_context_retrieval,
 };
-use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
+use edit_prediction::{
+    EditPredictionStore,
+    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
+};
 use gpui::AsyncApp;
 use std::sync::Arc;
 use zeta_prompt::format_zeta_prompt;
@@ -15,11 +19,20 @@ pub async fn run_format_prompt(
     app_state: Arc<EpAppState>,
     mut cx: AsyncApp,
 ) {
-    run_context_retrieval(example, app_state, cx.clone()).await;
-
-    let prompt = match prompt_format {
-        PromptFormat::Teacher => TeacherPrompt::format(example),
+    run_context_retrieval(example, app_state.clone(), cx.clone()).await;
+
+    match prompt_format {
+        PromptFormat::Teacher => {
+            let prompt = TeacherPrompt::format_prompt(example);
+            example.prompt = Some(ExamplePrompt {
+                input: prompt,
+                expected_output: example.expected_patch.clone(), // TODO
+                format: prompt_format,
+            });
+        }
         PromptFormat::Zeta2 => {
+            run_load_project(example, app_state, cx.clone()).await;
+
             let ep_store = cx
                 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
                 .unwrap();
@@ -41,30 +54,28 @@ pub async fn run_format_prompt(
                     )
                 })
                 .unwrap();
-            format_zeta_prompt(&input)
+            let prompt = format_zeta_prompt(&input);
+            let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
+            example.prompt = Some(ExamplePrompt {
+                input: prompt,
+                expected_output,
+                format: prompt_format,
+            });
         }
     };
-
-    example.prompt = Some(ExamplePrompt {
-        input: prompt,
-        expected_output: example.expected_patch.clone(), // TODO
-        format: prompt_format,
-    });
 }
 
-pub trait PromptFormatter {
-    fn format(example: &Example) -> String;
-}
+pub struct TeacherPrompt;
 
-pub trait PromptParser {
-    /// Return unified diff patch of prediction given raw LLM response
-    fn parse(example: &Example, response: &str) -> String;
-}
+impl TeacherPrompt {
+    const PROMPT: &str = include_str!("teacher.prompt.md");
+    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
+    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
 
-pub struct TeacherPrompt;
+    /// Truncate edit history to this number of last lines
+    const MAX_HISTORY_LINES: usize = 128;
 
-impl PromptFormatter for TeacherPrompt {
-    fn format(example: &Example) -> String {
+    pub fn format_prompt(example: &Example) -> String {
         let edit_history = Self::format_edit_history(&example.edit_history);
         let context = Self::format_context(example);
         let editable_region = Self::format_editable_region(example);
@@ -76,15 +87,46 @@ impl PromptFormatter for TeacherPrompt {
 
         prompt
     }
-}
 
-impl TeacherPrompt {
-    const PROMPT: &str = include_str!("teacher.prompt.md");
-    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
-    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
+    pub fn parse(example: &Example, response: &str) -> String {
+        // Ideally, we should always be able to find cursor position in the retrieved context.
+        // In reality, sometimes we don't find it for these reasons:
+        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
+        //    (can be fixed by getting cursor coordinates at the load_example stage)
+        // 2. The context retriever just didn't include the cursor line.
+        //
+        // In that case, fall back to using `cursor_position` as the excerpt.
+        let cursor_file = &example
+            .buffer
+            .as_ref()
+            .expect("`buffer` should be filled in in the context collection step")
+            .content;
 
-    /// Truncate edit history to this number of last lines
-    const MAX_HISTORY_LINES: usize = 128;
+        // Extract updated (new) editable region from the model response
+        let new_editable_region = extract_last_codeblock(response);
+
+        // Reconstruct old editable region we sent to the model
+        let old_editable_region = Self::format_editable_region(example);
+        let old_editable_region = Self::extract_editable_region(&old_editable_region);
+        if !cursor_file.contains(&old_editable_region) {
+            panic!("Something's wrong: editable_region is not found in the cursor file")
+        }
+
+        // Apply editable region to a larger context and compute diff.
+        // This is needed to get better context lines around the editable region
+        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
+        let diff = language::unified_diff(&cursor_file, &edited_file);
+
+        let diff = indoc::formatdoc! {"
+            --- a/{path}
+            +++ b/{path}
+            {diff}",
+            path = example.cursor_path.to_string_lossy(),
+            diff = diff,
+        };
+
+        diff
+    }
 
     fn format_edit_history(edit_history: &str) -> String {
         // Strip comments ("garbage lines") from edit history
@@ -157,49 +199,6 @@ impl TeacherPrompt {
     }
 }
 
-impl PromptParser for TeacherPrompt {
-    fn parse(example: &Example, response: &str) -> String {
-        // Ideally, we should always be able to find cursor position in the retrieved context.
-        // In reality, sometimes we don't find it for these reasons:
-        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
-        //    (can be fixed by getting cursor coordinates at the load_example stage)
-        // 2. Context retriever just didn't include cursor line.
-        //
-        // In that case, fallback to using `cursor_position` as excerpt.
-        let cursor_file = &example
-            .buffer
-            .as_ref()
-            .expect("`buffer` should be filled in in the context collection step")
-            .content;
-
-        // Extract updated (new) editable region from the model response
-        let new_editable_region = extract_last_codeblock(response);
-
-        // Reconstruct old editable region we sent to the model
-        let old_editable_region = Self::format_editable_region(example);
-        let old_editable_region = Self::extract_editable_region(&old_editable_region);
-        if !cursor_file.contains(&old_editable_region) {
-            panic!("Something's wrong: editable_region is not found in the cursor file")
-        }
-
-        // Apply editable region to a larger context and compute diff.
-        // This is needed to get a better context lines around the editable region
-        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
-        let diff = language::unified_diff(&cursor_file, &edited_file);
-
-        let diff = indoc::formatdoc! {"
-            --- a/{path}
-            +++ b/{path}
-            {diff}
-            ",
-            path = example.cursor_path.to_string_lossy(),
-            diff = diff,
-        };
-
-        diff
-    }
-}
-
 fn extract_last_codeblock(text: &str) -> String {
     let mut last_block = None;
     let mut search_start = 0;
@@ -221,7 +220,7 @@ fn extract_last_codeblock(text: &str) -> String {
         }
 
         if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
-            let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
+            let code_block = &text[backtick_end + 1..backtick_end + end_pos];
             last_block = Some(code_block.to_string());
             search_start = backtick_end + end_pos + backtick_count;
         } else {
@@ -250,7 +249,7 @@ mod tests {
             `````
             "};
         let last_block = extract_last_codeblock(text);
-        assert_eq!(last_block, "last block");
+        assert_eq!(last_block, "last block\n");
     }
 
     #[test]

crates/edit_prediction_cli/src/main.rs

@@ -1,4 +1,5 @@
 mod anthropic_client;
+mod distill;
 mod example;
 mod format_prompt;
 mod headless;
@@ -16,6 +17,7 @@ use reqwest_client::ReqwestClient;
 use serde::{Deserialize, Serialize};
 use std::{path::PathBuf, sync::Arc};
 
+use crate::distill::run_distill;
 use crate::example::{read_examples, write_examples};
 use crate::format_prompt::run_format_prompt;
 use crate::load_project::run_load_project;
@@ -54,6 +56,9 @@ enum Command {
     Predict(PredictArgs),
     /// Computes a score based on actual and expected patches
     Score(PredictArgs),
+    /// Prepares a distillation dataset by copying predicted outputs to
+    /// expected outputs and clearing predictions, scores, and prompts.
+    Distill,
     /// Print aggregated scores
     Eval(PredictArgs),
     /// Remove git repositories and worktrees
@@ -87,6 +92,7 @@ enum PredictionProvider {
     Zeta1,
     Zeta2,
     Teacher,
+    TeacherNonBatching,
 }
 
 impl EpArgs {
@@ -175,6 +181,9 @@ fn main() {
                                 )
                                 .await;
                             }
+                            Command::Distill => {
+                                run_distill(example).await;
+                            }
                             Command::Score(args) | Command::Eval(args) => {
                                 run_scoring(example, &args, app_state, cx).await;
                             }

crates/edit_prediction_cli/src/predict.rs

@@ -2,7 +2,7 @@ use crate::{
     PredictionProvider, PromptFormat,
     anthropic_client::AnthropicClient,
     example::{Example, ExamplePrediction},
-    format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
+    format_prompt::{TeacherPrompt, run_format_prompt},
     headless::EpAppState,
     load_project::run_load_project,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
@@ -30,20 +30,24 @@ pub async fn run_prediction(
         return;
     }
 
-    run_load_project(example, app_state.clone(), cx.clone()).await;
     run_context_retrieval(example, app_state.clone(), cx.clone()).await;
 
     let provider = provider.unwrap();
 
-    if matches!(provider, PredictionProvider::Teacher) {
+    if matches!(
+        provider,
+        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
+    ) {
         if example.prompt.is_none() {
             run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
         }
 
-        let batched = true;
+        let batched = matches!(provider, PredictionProvider::Teacher);
         return predict_anthropic(example, repetition_count, batched).await;
     }
 
+    run_load_project(example, app_state.clone(), cx.clone()).await;
+
     if matches!(
         provider,
         PredictionProvider::Zeta1 | PredictionProvider::Zeta2
@@ -75,7 +79,9 @@ pub async fn run_prediction(
                 PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                 PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                 PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
-                PredictionProvider::Teacher => unreachable!(),
+                PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
+                    unreachable!()
+                }
             };
             store.set_edit_prediction_model(model);
         })

crates/edit_prediction_cli/src/teacher.prompt.md

@@ -18,6 +18,7 @@ Focus on:
 Rules:
 - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
 - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+- Keep existing formatting unless changing it is absolutely necessary.
 
 Input format:
 - You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.