Created by Max Brunsfeld, Oleksiy Syvokon, and Agus Zubiaga
Release Notes:
- N/A
---------
Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
crates/edit_prediction/Cargo.toml | 2 +-
crates/edit_prediction/src/edit_prediction.rs | 26 +-
crates/edit_prediction/src/udiff.rs | 2 +-
crates/edit_prediction/src/zeta2.rs | 20 +
crates/edit_prediction_cli/Cargo.toml | 2 +-
crates/edit_prediction_cli/src/distill.rs | 14 +
crates/edit_prediction_cli/src/example.rs | 23 +-
crates/edit_prediction_cli/src/format_prompt.rs | 147 ++++++++---------
crates/edit_prediction_cli/src/main.rs | 9 +
crates/edit_prediction_cli/src/predict.rs | 16 +
crates/edit_prediction_cli/src/teacher.prompt.md | 1 +
11 files changed, 149 insertions(+), 113 deletions(-)

crates/edit_prediction/Cargo.toml
@@ -12,7 +12,7 @@ workspace = true
path = "src/edit_prediction.rs"
[features]
-eval-support = []
+cli-support = []
[dependencies]
ai_onboarding.workspace = true

crates/edit_prediction/src/edit_prediction.rs
@@ -55,7 +55,7 @@ pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
-#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
+#[cfg(any(test, feature = "test-support", feature = "cli-support"))]
pub mod udiff;
mod zed_edit_prediction_delegate;
@@ -158,7 +158,7 @@ pub struct EditPredictionStore {
use_context: bool,
options: ZetaOptions,
update_required: bool,
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
pub sweep_ai: SweepAi,
@@ -505,7 +505,7 @@ impl EditPredictionStore {
},
),
update_required: false,
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
sweep_ai: SweepAi::new(cx),
@@ -554,7 +554,7 @@ impl EditPredictionStore {
.is_some()
}
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
pub fn with_eval_cache(&mut self, cache: Arc<dyn EvalCache>) {
self.eval_cache = Some(cache);
}
@@ -1590,8 +1590,8 @@ impl EditPredictionStore {
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
- #[cfg(feature = "eval-support")] eval_cache: Option<Arc<dyn EvalCache>>,
- #[cfg(feature = "eval-support")] eval_cache_kind: EvalCacheEntryKind,
+ #[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
+ #[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
http_client::Url::parse(&predict_edits_url)?
@@ -1601,7 +1601,7 @@ impl EditPredictionStore {
.build_zed_llm_url("/predict_edits/raw", &[])?
};
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
let cache_key = if let Some(cache) = eval_cache {
use collections::FxHasher;
use std::hash::{Hash, Hasher};
@@ -1635,7 +1635,7 @@ impl EditPredictionStore {
)
.await?;
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
if let Some((cache, request, key)) = cache_key {
cache.write(key, &request, &serde_json::to_string_pretty(&response)?);
}
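
The two hunks above compute a cache key before the request goes out and write the response back afterwards. A minimal sketch of that keying pattern, assuming the serialized request body is what gets hashed (the exact hashed fields are not visible in these hunks) and that the helper lives inside the `edit_prediction` crate:

```rust
use collections::FxHasher; // same hasher the hunk above imports
use std::hash::{Hash, Hasher};

// Hypothetical helper: derive an eval-cache key from a serialized request.
fn cache_key_for(kind: EvalCacheEntryKind, request_json: &str) -> EvalCacheKey {
    let mut hasher = FxHasher::default();
    request_json.hash(&mut hasher);
    (kind, hasher.finish())
}
```
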
@@ -1767,7 +1767,7 @@ impl EditPredictionStore {
}
}
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
pub fn set_context_for_buffer(
&mut self,
project: &Entity<Project>,
@@ -1892,10 +1892,10 @@ pub struct ZedUpdateRequiredError {
minimum_version: Version,
}
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
pub type EvalCacheKey = (EvalCacheEntryKind, u64);
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalCacheEntryKind {
Context,
@@ -1903,7 +1903,7 @@ pub enum EvalCacheEntryKind {
Prediction,
}
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
impl std::fmt::Display for EvalCacheEntryKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@@ -1914,7 +1914,7 @@ impl std::fmt::Display for EvalCacheEntryKind {
}
}
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
pub trait EvalCache: Send + Sync {
fn read(&self, key: EvalCacheKey) -> Option<String>;
fn write(&self, key: EvalCacheKey, input: &str, value: &str);
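
`EvalCache` is the hook the CLI uses to memoize LLM round-trips. A minimal file-backed sketch of an implementor (the `FsEvalCache` type, directory layout, and file naming are assumptions, not code from this PR):

```rust
use edit_prediction::{EvalCache, EvalCacheKey};
use std::{fs, path::PathBuf};

// Hypothetical on-disk cache keyed by entry kind and request hash.
struct FsEvalCache {
    root: PathBuf,
}

impl FsEvalCache {
    fn path_for(&self, key: EvalCacheKey) -> PathBuf {
        // EvalCacheEntryKind implements Display, so it can name a subdirectory.
        self.root
            .join(key.0.to_string())
            .join(format!("{:016x}.json", key.1))
    }
}

impl EvalCache for FsEvalCache {
    fn read(&self, key: EvalCacheKey) -> Option<String> {
        fs::read_to_string(self.path_for(key)).ok()
    }

    fn write(&self, key: EvalCacheKey, input: &str, value: &str) {
        let path = self.path_for(key);
        if let Some(dir) = path.parent() {
            let _ = fs::create_dir_all(dir);
        }
        // Keep the request next to the response so cache hits stay auditable.
        let _ = fs::write(path.with_extension("input.txt"), input);
        let _ = fs::write(&path, value);
    }
}
```

A store would then opt in via the setter added above, e.g. `store.with_eval_cache(Arc::new(FsEvalCache { root }))`.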

crates/edit_prediction/src/udiff.rs
@@ -138,7 +138,7 @@ pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
DiffEvent::Hunk { hunk, .. } => {
let hunk_offset = text
.find(&hunk.context)
- .ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
+ .ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
for edit in hunk.edits.iter().rev() {
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
text.replace_range(range, &edit.text);
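
For orientation, `apply_diff_to_string` anchors each hunk by searching for its context in the target text, then applies the hunk's edits back-to-front so earlier offsets stay valid. A hypothetical call under that contract (made-up input, not a test from this PR):

```rust
use edit_prediction::udiff::apply_diff_to_string;

fn demo() -> anyhow::Result<()> {
    let text = "fn main() {\n    println!(\"hi\");\n}\n";
    let patch = "@@ -1,3 +1,3 @@\n fn main() {\n-    println!(\"hi\");\n+    println!(\"hi, world\");\n }\n";
    assert_eq!(
        apply_diff_to_string(patch, text)?,
        "fn main() {\n    println!(\"hi, world\");\n}\n"
    );
    Ok(())
}
```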

crates/edit_prediction/src/zeta2.rs
@@ -1,4 +1,4 @@
-#[cfg(feature = "eval-support")]
+#[cfg(feature = "cli-support")]
use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
@@ -44,7 +44,7 @@ pub fn request_prediction_with_zeta2(
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
let eval_cache = store.eval_cache.clone();
let request_task = cx.background_spawn({
@@ -95,9 +95,9 @@ pub fn request_prediction_with_zeta2(
client,
llm_token,
app_version,
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
eval_cache,
- #[cfg(feature = "eval-support")]
+ #[cfg(feature = "cli-support")]
EvalCacheEntryKind::Prediction,
)
.await;
@@ -226,3 +226,15 @@ pub fn zeta2_prompt_input(
};
(editable_offset_range, prompt_input)
}
+
+#[cfg(feature = "cli-support")]
+pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> String {
+ eprintln!("{}", patch);
+ eprintln!("---------------------");
+ eprintln!("{}", input.cursor_excerpt);
+ crate::udiff::apply_diff_to_string(
+ patch,
+ &input.cursor_excerpt[input.editable_range_in_excerpt.clone()],
+ )
+ .unwrap()
+}

crates/edit_prediction_cli/Cargo.toml
@@ -52,7 +52,7 @@ sqlez_macros.workspace = true
terminal_view.workspace = true
util.workspace = true
watch.workspace = true
-edit_prediction = { workspace = true, features = ["eval-support"] }
+edit_prediction = { workspace = true, features = ["cli-support"] }
wasmtime.workspace = true
zeta_prompt.workspace = true
zlog.workspace = true

crates/edit_prediction_cli/src/distill.rs
@@ -0,0 +1,14 @@
+use std::mem;
+
+use crate::example::Example;
+
+pub async fn run_distill(example: &mut Example) {
+ let [prediction]: [_; 1] = mem::take(&mut example.predictions)
+ .try_into()
+ .expect("Run predict first with a single repetition");
+
+ example.expected_patch = prediction.actual_patch;
+ example.prompt = None;
+ example.predictions = Vec::new();
+ example.score = Vec::new();
+}
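
The `let [prediction]: [_; 1] = …` destructuring both asserts that exactly one prediction exists and takes ownership of it without cloning. A standalone sketch of the same pattern:

```rust
use std::mem;

// Take the sole element of a vector, panicking unless its length is exactly 1.
fn take_single<T>(items: &mut Vec<T>) -> T {
    let [item]: [T; 1] = mem::take(items)
        .try_into()
        .unwrap_or_else(|_| panic!("expected exactly one element"));
    item
}

fn main() {
    let mut predictions = vec!["patch"];
    assert_eq!(take_single(&mut predictions), "patch");
    assert!(predictions.is_empty());
}
```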

crates/edit_prediction_cli/src/example.rs
@@ -25,6 +25,7 @@ pub struct Example {
pub name: String,
pub repository_url: String,
pub revision: String,
+ #[serde(default)]
pub uncommitted_diff: String,
pub cursor_path: Arc<Path>,
pub cursor_position: String,
@@ -195,9 +196,9 @@ pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
.enumerate()
.map(|(line_ix, line)| {
let mut example =
- serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
+ serde_json::from_str::<Example>(line).unwrap_or_else(|error| {
panic!(
- "Failed to parse example on {}:{}",
+ "Failed to parse example on {}:{}\n{error}",
path.display(),
line_ix + 1
)
@@ -264,12 +265,12 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
state: None,
};
- let mut name = String::new();
let mut text = String::new();
let mut block_info: CowStr = "".into();
#[derive(PartialEq)]
enum Section {
+ Start,
UncommittedDiff,
EditHistory,
CursorPosition,
@@ -278,14 +279,16 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
Other,
}
- let mut current_section = Section::Other;
+ let mut current_section = Section::Start;
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
- if let Some((field, value)) = line.split_once('=') {
+ if let Section::Start = current_section
+ && let Some((field, value)) = line.split_once('=')
+ {
match field.trim() {
REPOSITORY_URL_FIELD => {
example.repository_url = value.trim().to_string();
@@ -297,14 +300,6 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
}
}
}
- Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
- if !name.is_empty() {
- anyhow::bail!(
- "Found multiple H1 headings. There should only be one with the name of the example."
- );
- }
- name = mem::take(&mut text);
- }
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
@@ -363,7 +358,7 @@ fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
Section::ExpectedPatch => {
example.expected_patch = mem::take(&mut text);
}
- Section::Other => {}
+ Section::Start | Section::Other => {}
}
}
_ => {}
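
With the new `Section::Start` state, `field = value` lines are only parsed before the first H2 heading, so a stray `=` inside a later section can no longer clobber the metadata. An example file header would look roughly like this (field and heading names are inferred from the constants referenced above; the exact format is an assumption):

```rust
// Hypothetical header of a markdown example consumed by parse_markdown_example.
let example_md = r#"repository_url = https://github.com/zed-industries/zed
revision = 0123abcd

## Uncommitted Diff
...
"#;
```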

crates/edit_prediction_cli/src/format_prompt.rs
@@ -2,9 +2,13 @@ use crate::{
PromptFormat,
example::{Example, ExamplePrompt},
headless::EpAppState,
+ load_project::run_load_project,
retrieve_context::run_context_retrieval,
};
-use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
+use edit_prediction::{
+ EditPredictionStore,
+ zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
+};
use gpui::AsyncApp;
use std::sync::Arc;
use zeta_prompt::format_zeta_prompt;
@@ -15,11 +19,20 @@ pub async fn run_format_prompt(
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
- run_context_retrieval(example, app_state, cx.clone()).await;
-
- let prompt = match prompt_format {
- PromptFormat::Teacher => TeacherPrompt::format(example),
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await;
+
+ match prompt_format {
+ PromptFormat::Teacher => {
+ let prompt = TeacherPrompt::format_prompt(example);
+ example.prompt = Some(ExamplePrompt {
+ input: prompt,
+ expected_output: example.expected_patch.clone(), // TODO
+ format: prompt_format,
+ });
+ }
PromptFormat::Zeta2 => {
+ run_load_project(example, app_state, cx.clone()).await;
+
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
@@ -41,30 +54,28 @@ pub async fn run_format_prompt(
)
})
.unwrap();
- format_zeta_prompt(&input)
+ let prompt = format_zeta_prompt(&input);
+ let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
+ example.prompt = Some(ExamplePrompt {
+ input: prompt,
+ expected_output,
+ format: prompt_format,
+ });
}
};
-
- example.prompt = Some(ExamplePrompt {
- input: prompt,
- expected_output: example.expected_patch.clone(), // TODO
- format: prompt_format,
- });
}
-pub trait PromptFormatter {
- fn format(example: &Example) -> String;
-}
+pub struct TeacherPrompt;
-pub trait PromptParser {
- /// Return unified diff patch of prediction given raw LLM response
- fn parse(example: &Example, response: &str) -> String;
-}
+impl TeacherPrompt {
+ const PROMPT: &str = include_str!("teacher.prompt.md");
+ pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
+ pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
-pub struct TeacherPrompt;
+ /// Truncate edit history to this number of last lines
+ const MAX_HISTORY_LINES: usize = 128;
-impl PromptFormatter for TeacherPrompt {
- fn format(example: &Example) -> String {
+ pub fn format_prompt(example: &Example) -> String {
let edit_history = Self::format_edit_history(&example.edit_history);
let context = Self::format_context(example);
let editable_region = Self::format_editable_region(example);
@@ -76,15 +87,46 @@ impl PromptFormatter for TeacherPrompt {
prompt
}
-}
-impl TeacherPrompt {
- const PROMPT: &str = include_str!("teacher.prompt.md");
- pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
- pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
+ pub fn parse(example: &Example, response: &str) -> String {
+ // Ideally, we should always be able to find cursor position in the retrieved context.
+ // In reality, sometimes we don't find it for these reasons:
+ // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
+ // (can be fixed by getting cursor coordinates at the load_example stage)
+ // 2. Context retriever just didn't include cursor line.
+ //
+ // In that case, fallback to using `cursor_position` as excerpt.
+ let cursor_file = &example
+ .buffer
+ .as_ref()
+ .expect("`buffer` should be filled in in the context collection step")
+ .content;
- /// Truncate edit history to this number of last lines
- const MAX_HISTORY_LINES: usize = 128;
+ // Extract updated (new) editable region from the model response
+ let new_editable_region = extract_last_codeblock(response);
+
+ // Reconstruct old editable region we sent to the model
+ let old_editable_region = Self::format_editable_region(example);
+ let old_editable_region = Self::extract_editable_region(&old_editable_region);
+ if !cursor_file.contains(&old_editable_region) {
+ panic!("Something's wrong: editable_region is not found in the cursor file")
+ }
+
+ // Apply editable region to a larger context and compute diff.
+ // This is needed to get a better context lines around the editable region
+ let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
+ let diff = language::unified_diff(&cursor_file, &edited_file);
+
+ let diff = indoc::formatdoc! {"
+ --- a/{path}
+ +++ b/{path}
+ {diff}",
+ path = example.cursor_path.to_string_lossy(),
+ diff = diff,
+ };
+
+ diff
+ }
fn format_edit_history(edit_history: &str) -> String {
// Strip comments ("garbage lines") from edit history
@@ -157,49 +199,6 @@ impl TeacherPrompt {
}
}
-impl PromptParser for TeacherPrompt {
- fn parse(example: &Example, response: &str) -> String {
- // Ideally, we should always be able to find cursor position in the retrieved context.
- // In reality, sometimes we don't find it for these reasons:
- // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
- // (can be fixed by getting cursor coordinates at the load_example stage)
- // 2. Context retriever just didn't include cursor line.
- //
- // In that case, fallback to using `cursor_position` as excerpt.
- let cursor_file = &example
- .buffer
- .as_ref()
- .expect("`buffer` should be filled in in the context collection step")
- .content;
-
- // Extract updated (new) editable region from the model response
- let new_editable_region = extract_last_codeblock(response);
-
- // Reconstruct old editable region we sent to the model
- let old_editable_region = Self::format_editable_region(example);
- let old_editable_region = Self::extract_editable_region(&old_editable_region);
- if !cursor_file.contains(&old_editable_region) {
- panic!("Something's wrong: editable_region is not found in the cursor file")
- }
-
- // Apply editable region to a larger context and compute diff.
- // This is needed to get a better context lines around the editable region
- let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
- let diff = language::unified_diff(&cursor_file, &edited_file);
-
- let diff = indoc::formatdoc! {"
- --- a/{path}
- +++ b/{path}
- {diff}
- ",
- path = example.cursor_path.to_string_lossy(),
- diff = diff,
- };
-
- diff
- }
-}
-
fn extract_last_codeblock(text: &str) -> String {
let mut last_block = None;
let mut search_start = 0;
@@ -221,7 +220,7 @@ fn extract_last_codeblock(text: &str) -> String {
}
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
- let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
+ let code_block = &text[backtick_end + 1..backtick_end + end_pos];
last_block = Some(code_block.to_string());
search_start = backtick_end + end_pos + backtick_count;
} else {
@@ -250,7 +249,7 @@ mod tests {
`````
"};
let last_block = extract_last_codeblock(text);
- assert_eq!(last_block, "last block");
+ assert_eq!(last_block, "last block\n");
}
#[test]
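
The slicing fix above stops `extract_last_codeblock` from dropping the final character before the closing fence, which is why the test now expects `"last block\n"` with its trailing newline. A simplified standalone version of the corrected logic (fixed-length fences only, unlike the real function):

````rust
// Simplified sketch: extract the body of the last fenced code block,
// keeping the trailing newline before the closing fence.
fn last_fenced_block(text: &str) -> Option<&str> {
    let close = text.rfind("```")?;          // closing fence
    let open = text[..close].rfind("```")?;  // opening fence
    let body = text[open..].find('\n')? + open + 1; // skip the info string line
    Some(&text[body..close])
}

fn main() {
    let text = "intro\n```rust\nlast block\n```\n";
    assert_eq!(last_fenced_block(text), Some("last block\n"));
}
````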

crates/edit_prediction_cli/src/main.rs
@@ -1,4 +1,5 @@
mod anthropic_client;
+mod distill;
mod example;
mod format_prompt;
mod headless;
@@ -16,6 +17,7 @@ use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
+use crate::distill::run_distill;
use crate::example::{read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
@@ -54,6 +56,9 @@ enum Command {
Predict(PredictArgs),
/// Computes a score based on actual and expected patches
Score(PredictArgs),
+ /// Prepares a distillation dataset by copying expected outputs to
+ /// predicted outputs and removing actual outputs and prompts.
+ Distill,
/// Print aggregated scores
Eval(PredictArgs),
/// Remove git repositories and worktrees
@@ -87,6 +92,7 @@ enum PredictionProvider {
Zeta1,
Zeta2,
Teacher,
+ TeacherNonBatching,
}
impl EpArgs {
@@ -175,6 +181,9 @@ fn main() {
)
.await;
}
+ Command::Distill => {
+ run_distill(example).await;
+ }
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state, cx).await;
}

crates/edit_prediction_cli/src/predict.rs
@@ -2,7 +2,7 @@ use crate::{
PredictionProvider, PromptFormat,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction},
- format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
+ format_prompt::{TeacherPrompt, run_format_prompt},
headless::EpAppState,
load_project::run_load_project,
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
@@ -30,20 +30,24 @@ pub async fn run_prediction(
return;
}
- run_load_project(example, app_state.clone(), cx.clone()).await;
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
let provider = provider.unwrap();
- if matches!(provider, PredictionProvider::Teacher) {
+ if matches!(
+ provider,
+ PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
+ ) {
if example.prompt.is_none() {
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
}
- let batched = true;
+ let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
}
+ run_load_project(example, app_state.clone(), cx.clone()).await;
+
if matches!(
provider,
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
@@ -75,7 +79,9 @@ pub async fn run_prediction(
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
- PredictionProvider::Teacher => unreachable!(),
+ PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
+ unreachable!()
+ }
};
store.set_edit_prediction_model(model);
})

crates/edit_prediction_cli/src/teacher.prompt.md
@@ -18,6 +18,7 @@ Focus on:
Rules:
- Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
- Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
+- Keep existing formatting unless changing it is absolutely necessary
Input format:
- You receive small code fragments called context (structs, field definitions, function signatures, etc.). They may or may not be relevant.