diff --git a/Cargo.lock b/Cargo.lock index a8f602640838d3634863fc60a2399e8a9a9f5288..ff1041695e1f1e95bcbc05798d1a1e0f953533ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3111,16 +3111,6 @@ dependencies = [ "uuid", ] -[[package]] -name = "cloud_zeta2_prompt" -version = "0.1.0" -dependencies = [ - "anyhow", - "cloud_llm_client", - "indoc", - "serde", -] - [[package]] name = "cmake" version = "0.1.54" @@ -5119,7 +5109,6 @@ dependencies = [ "clock", "cloud_api_types", "cloud_llm_client", - "cloud_zeta2_prompt", "collections", "copilot", "credentials_provider", @@ -5150,8 +5139,6 @@ dependencies = [ "serde", "serde_json", "settings", - "smol", - "strsim", "strum 0.27.2", "telemetry", "telemetry_events", @@ -5162,6 +5149,7 @@ dependencies = [ "workspace", "worktree", "zed_actions", + "zeta_prompt", "zlog", ] @@ -5175,11 +5163,10 @@ dependencies = [ "clap", "client", "cloud_llm_client", - "cloud_zeta2_prompt", "collections", "debug_adapter_extension", + "dirs 4.0.0", "edit_prediction", - "edit_prediction_context", "extension", "fs", "futures 0.3.31", @@ -5209,9 +5196,10 @@ dependencies = [ "sqlez", "sqlez_macros", "terminal_view", - "toml 0.8.23", "util", + "wasmtime", "watch", + "zeta_prompt", "zlog", ] @@ -5239,6 +5227,7 @@ dependencies = [ "text", "tree-sitter", "util", + "zeta_prompt", "zlog", ] @@ -5260,7 +5249,6 @@ dependencies = [ "buffer_diff", "client", "cloud_llm_client", - "cloud_zeta2_prompt", "codestral", "command_palette_hooks", "copilot", @@ -5291,6 +5279,7 @@ dependencies = [ "util", "workspace", "zed_actions", + "zeta_prompt", ] [[package]] @@ -20933,6 +20922,13 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "zeta_prompt" +version = "0.1.0" +dependencies = [ + "serde", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 0ad4d2b14523988aa0dd6e3bfc935f84bcd0d8d9..fcbe5c829ded21a9aaf9e6bec93b9955b1db6447 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,6 @@ members = [ "crates/cloud_api_client", "crates/cloud_api_types", "crates/cloud_llm_client", - "crates/cloud_zeta2_prompt", "crates/collab", "crates/collab_ui", "crates/collections", @@ -202,6 +201,7 @@ members = [ "crates/zed_actions", "crates/zed_env_vars", "crates/edit_prediction_cli", + "crates/zeta_prompt", "crates/zlog", "crates/zlog_settings", "crates/ztracing", @@ -266,7 +266,6 @@ clock = { path = "crates/clock" } cloud_api_client = { path = "crates/cloud_api_client" } cloud_api_types = { path = "crates/cloud_api_types" } cloud_llm_client = { path = "crates/cloud_llm_client" } -cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" } collab_ui = { path = "crates/collab_ui" } collections = { path = "crates/collections", version = "0.1.0" } command_palette = { path = "crates/command_palette" } @@ -425,6 +424,7 @@ zed = { path = "crates/zed" } zed_actions = { path = "crates/zed_actions" } zed_env_vars = { path = "crates/zed_env_vars" } edit_prediction = { path = "crates/edit_prediction" } +zeta_prompt = { path = "crates/zeta_prompt" } zlog = { path = "crates/zlog" } zlog_settings = { path = "crates/zlog_settings" } ztracing = { path = "crates/ztracing" } @@ -657,6 +657,7 @@ time = { version = "0.3", features = [ tiny_http = "0.8" tokio = { version = "1" } tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] } +tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] } toml = "0.8" toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] } tower-http = "0.4.4" diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 7149ad4f55feaae5b596a39a3dd460d71cc5daa5..50cf12b977a62d56bf9d4a036165917a5dfff2fc 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -53,7 +53,7 @@ text.workspace = true thiserror.workspace = true time.workspace = true tiny_http.workspace = true -tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] } +tokio-socks.workspace = true tokio.workspace = true url.workspace = true util.workspace = true diff --git a/crates/cloud_zeta2_prompt/Cargo.toml b/crates/cloud_zeta2_prompt/Cargo.toml deleted file mode 100644 index a15e3fe43c28349920433272c4040ccc58ff4cb4..0000000000000000000000000000000000000000 --- a/crates/cloud_zeta2_prompt/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "cloud_zeta2_prompt" -version = "0.1.0" -publish.workspace = true -edition.workspace = true -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/cloud_zeta2_prompt.rs" - -[dependencies] -anyhow.workspace = true -cloud_llm_client.workspace = true -indoc.workspace = true -serde.workspace = true diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs deleted file mode 100644 index 62bfa45f47d0fdfefa9fbd72320c0ddee71cbc47..0000000000000000000000000000000000000000 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ /dev/null @@ -1,485 +0,0 @@ -use anyhow::Result; -use cloud_llm_client::predict_edits_v3::{ - self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile, -}; -use indoc::indoc; -use std::cmp; -use std::fmt::Write; -use std::path::Path; -use std::sync::Arc; - -pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024; - -pub const CURSOR_MARKER: &str = "<|user_cursor|>"; -/// NOTE: Differs from zed version of constant - includes a newline -pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n"; -/// NOTE: Differs from zed version of constant - includes a newline -pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n"; - -const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#" - You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase. - - ## Edit History - - "#}; - -const MINIMAL_PROMPT_REMINDER: &str = indoc! {" - --- - - Please analyze the edit history and the files, then provide the unified diff for your predicted edits. - Do not include the cursor marker in your output. - If you're editing multiple files, be sure to reflect filename in the hunk's header. - "}; - -const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#" - # Instructions - - You are an edit prediction agent in a code editor. - - Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish. - Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make. - Always continue along the user's current trajectory, rather than changing course. - - ## Output Format - - You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change - along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format: - - - - OLD TEXT 1 HERE - - - NEW TEXT 1 HERE - - - - OLD TEXT 1 HERE - - - NEW TEXT 1 HERE - - - - - Specify the file to edit using the `path` attribute. - - Use `` and `` tags to replace content - - `` must exactly match existing file content, including indentation - - `` cannot be empty - - Do not escape quotes, newlines, or other characters within tags - - Always close all tags properly - - Don't include the <|user_cursor|> marker in your output. - - ## Edit History - -"#}; - -const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#" - --- - - Remember that the edits in the edit history have already been applied. -"#}; - -pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result { - let prompt_data = PromptData { - events: request.events.clone(), - cursor_point: request.cursor_point, - cursor_path: request.excerpt_path.clone(), - included_files: request.related_files.clone(), - }; - match request.prompt_format { - PromptFormat::MinimalQwen => { - return Ok(MinimalQwenPrompt.render(&prompt_data)); - } - PromptFormat::SeedCoder1120 => { - return Ok(SeedCoder1120Prompt.render(&prompt_data)); - } - _ => (), - }; - - let insertions = match request.prompt_format { - PromptFormat::Minimal | PromptFormat::OldTextNewText => { - vec![(request.cursor_point, CURSOR_MARKER)] - } - PromptFormat::OnlySnippets => vec![], - PromptFormat::MinimalQwen => unreachable!(), - PromptFormat::SeedCoder1120 => unreachable!(), - }; - - let mut prompt = match request.prompt_format { - PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(), - PromptFormat::OnlySnippets => String::new(), - PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(), - PromptFormat::MinimalQwen => unreachable!(), - PromptFormat::SeedCoder1120 => unreachable!(), - }; - - if request.events.is_empty() { - prompt.push_str("(No edit history)\n\n"); - } else { - let edit_preamble = if request.prompt_format == PromptFormat::Minimal { - "The following are the latest edits made by the user, from earlier to later.\n\n" - } else { - "Here are the latest edits made by the user, from earlier to later.\n\n" - }; - prompt.push_str(edit_preamble); - push_events(&mut prompt, &request.events); - } - - let excerpts_preamble = match request.prompt_format { - PromptFormat::Minimal => indoc! {" - ## Part of the file under the cursor - - (The cursor marker <|user_cursor|> indicates the current user cursor position. - The file is in current state, edits from edit history has been applied. - We only show part of the file around the cursor. - You can only edit exactly this part of the file. - We prepend line numbers (e.g., `123|`); they are not part of the file.) - "}, - PromptFormat::OldTextNewText => indoc! {" - ## Code Excerpts - - Here is some excerpts of code that you should take into account to predict the next edit. - - The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history. - - In addition other excerpts are included to better understand what the edit will be, including the declaration - or references of symbols around the cursor, or other similar code snippets that may need to be updated - following patterns that appear in the edit history. - - Consider each of them carefully in relation to the edit history, and that the user may not have navigated - to the next place they want to edit yet. - - Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs. - "}, - PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => { - indoc! {" - ## Code Excerpts - - The cursor marker <|user_cursor|> indicates the current user cursor position. - The file is in current state, edits from edit history have been applied. - "} - } - }; - - prompt.push_str(excerpts_preamble); - prompt.push('\n'); - - let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal); - for related_file in &request.related_files { - if request.prompt_format == PromptFormat::Minimal { - write_codeblock_with_filename( - &related_file.path, - &related_file.excerpts, - if related_file.path == request.excerpt_path { - &insertions - } else { - &[] - }, - related_file.max_row, - include_line_numbers, - &mut prompt, - ); - } else { - write_codeblock( - &related_file.path, - &related_file.excerpts, - if related_file.path == request.excerpt_path { - &insertions - } else { - &[] - }, - related_file.max_row, - include_line_numbers, - &mut prompt, - ); - } - } - - match request.prompt_format { - PromptFormat::OldTextNewText => { - prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER); - } - PromptFormat::Minimal => { - prompt.push_str(MINIMAL_PROMPT_REMINDER); - } - _ => {} - } - - Ok(prompt) -} - -pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams { - match prompt_format { - PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(), - _ => GenerationParams::default(), - } -} - -pub fn write_codeblock<'a>( - path: &Path, - excerpts: impl IntoIterator, - sorted_insertions: &[(Point, &str)], - file_line_count: Line, - include_line_numbers: bool, - output: &'a mut String, -) { - writeln!(output, "`````{}", DiffPathFmt(path)).unwrap(); - - write_excerpts( - excerpts, - sorted_insertions, - file_line_count, - include_line_numbers, - output, - ); - write!(output, "`````\n\n").unwrap(); -} - -fn write_codeblock_with_filename<'a>( - path: &Path, - excerpts: impl IntoIterator, - sorted_insertions: &[(Point, &str)], - file_line_count: Line, - include_line_numbers: bool, - output: &'a mut String, -) { - writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap(); - - write_excerpts( - excerpts, - sorted_insertions, - file_line_count, - include_line_numbers, - output, - ); - write!(output, "`````\n\n").unwrap(); -} - -pub fn write_excerpts<'a>( - excerpts: impl IntoIterator, - sorted_insertions: &[(Point, &str)], - file_line_count: Line, - include_line_numbers: bool, - output: &mut String, -) { - let mut current_row = Line(0); - let mut sorted_insertions = sorted_insertions.iter().peekable(); - - for excerpt in excerpts { - if excerpt.start_line > current_row { - writeln!(output, "…").unwrap(); - } - if excerpt.text.is_empty() { - return; - } - - current_row = excerpt.start_line; - - for mut line in excerpt.text.lines() { - if include_line_numbers { - write!(output, "{}|", current_row.0 + 1).unwrap(); - } - - while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() { - match current_row.cmp(&insertion_location.line) { - cmp::Ordering::Equal => { - let (prefix, suffix) = line.split_at(insertion_location.column as usize); - output.push_str(prefix); - output.push_str(insertion_marker); - line = suffix; - sorted_insertions.next(); - } - cmp::Ordering::Less => break, - cmp::Ordering::Greater => { - sorted_insertions.next(); - break; - } - } - } - output.push_str(line); - output.push('\n'); - current_row.0 += 1; - } - } - - if current_row < file_line_count { - writeln!(output, "…").unwrap(); - } -} - -pub fn push_events(output: &mut String, events: &[Arc]) { - if events.is_empty() { - return; - }; - - writeln!(output, "`````diff").unwrap(); - for event in events { - writeln!(output, "{}", event).unwrap(); - } - writeln!(output, "`````\n").unwrap(); -} - -struct PromptData { - events: Vec>, - cursor_point: Point, - cursor_path: Arc, // TODO: make a common struct with cursor_point - included_files: Vec, -} - -#[derive(Default)] -pub struct GenerationParams { - pub temperature: Option, - pub top_p: Option, - pub stop: Option>, -} - -trait PromptFormatter { - fn render(&self, data: &PromptData) -> String; - - fn generation_params() -> GenerationParams { - return GenerationParams::default(); - } -} - -struct MinimalQwenPrompt; - -impl PromptFormatter for MinimalQwenPrompt { - fn render(&self, data: &PromptData) -> String { - let edit_history = self.fmt_edit_history(data); - let context = self.fmt_context(data); - - format!( - "{instructions}\n\n{edit_history}\n\n{context}", - instructions = MinimalQwenPrompt::INSTRUCTIONS, - edit_history = edit_history, - context = context - ) - } -} - -impl MinimalQwenPrompt { - const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n"; - - fn fmt_edit_history(&self, data: &PromptData) -> String { - if data.events.is_empty() { - "(No edit history)\n\n".to_string() - } else { - let mut events_str = String::new(); - push_events(&mut events_str, &data.events); - format!( - "The following are the latest edits made by the user, from earlier to later.\n\n{}", - events_str - ) - } - } - - fn fmt_context(&self, data: &PromptData) -> String { - let mut context = String::new(); - let include_line_numbers = true; - - for related_file in &data.included_files { - writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap(); - - if related_file.path == data.cursor_path { - write!(context, "<|fim_prefix|>").unwrap(); - write_excerpts( - &related_file.excerpts, - &[(data.cursor_point, "<|fim_suffix|>")], - related_file.max_row, - include_line_numbers, - &mut context, - ); - writeln!(context, "<|fim_middle|>").unwrap(); - } else { - write_excerpts( - &related_file.excerpts, - &[], - related_file.max_row, - include_line_numbers, - &mut context, - ); - } - } - context - } -} - -struct SeedCoder1120Prompt; - -impl PromptFormatter for SeedCoder1120Prompt { - fn render(&self, data: &PromptData) -> String { - let edit_history = self.fmt_edit_history(data); - let context = self.fmt_context(data); - - format!( - "# Edit History:\n{edit_history}\n\n{context}", - edit_history = edit_history, - context = context - ) - } - - fn generation_params() -> GenerationParams { - GenerationParams { - temperature: Some(0.2), - top_p: Some(0.9), - stop: Some(vec!["<[end_of_sentence]>".into()]), - } - } -} - -impl SeedCoder1120Prompt { - fn fmt_edit_history(&self, data: &PromptData) -> String { - if data.events.is_empty() { - "(No edit history)\n\n".to_string() - } else { - let mut events_str = String::new(); - push_events(&mut events_str, &data.events); - events_str - } - } - - fn fmt_context(&self, data: &PromptData) -> String { - let mut context = String::new(); - let include_line_numbers = true; - - for related_file in &data.included_files { - writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap(); - - if related_file.path == data.cursor_path { - let fim_prompt = self.fmt_fim(&related_file, data.cursor_point); - context.push_str(&fim_prompt); - } else { - write_excerpts( - &related_file.excerpts, - &[], - related_file.max_row, - include_line_numbers, - &mut context, - ); - } - } - context - } - - fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String { - let mut buf = String::new(); - const FIM_SUFFIX: &str = "<[fim-suffix]>"; - const FIM_PREFIX: &str = "<[fim-prefix]>"; - const FIM_MIDDLE: &str = "<[fim-middle]>"; - write!(buf, "{}", FIM_PREFIX).unwrap(); - write_excerpts( - &file.excerpts, - &[(cursor_point, FIM_SUFFIX)], - file.max_row, - true, - &mut buf, - ); - - // Swap prefix and suffix parts - let index = buf.find(FIM_SUFFIX).unwrap(); - let prefix = &buf[..index]; - let suffix = &buf[index..]; - - format!("{}{}{}", suffix, prefix, FIM_MIDDLE) - } -} diff --git a/crates/edit_prediction/Cargo.toml b/crates/edit_prediction/Cargo.toml index 6e62cfa6f038671d595c5671de147cdc2125064d..c9237232e5e0bb6167fbeee8732d46ee584b080b 100644 --- a/crates/edit_prediction/Cargo.toml +++ b/crates/edit_prediction/Cargo.toml @@ -21,7 +21,6 @@ arrayvec.workspace = true brotli.workspace = true client.workspace = true cloud_llm_client.workspace = true -cloud_zeta2_prompt.workspace = true collections.workspace = true copilot.workspace = true credentials_provider.workspace = true @@ -50,8 +49,6 @@ semver.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true -smol.workspace = true -strsim.workspace = true strum.workspace = true telemetry.workspace = true telemetry_events.workspace = true @@ -62,6 +59,7 @@ uuid.workspace = true workspace.workspace = true worktree.workspace = true zed_actions.workspace = true +zeta_prompt.workspace = true [dev-dependencies] clock = { workspace = true, features = ["test-support"] } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 141fff3063b83d7e0003fddd6b4eba2d213d5fd5..b0d4a5f4d69c357fb0a153bee267a64dc0c465dd 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,14 +1,13 @@ use anyhow::Result; use arrayvec::ArrayVec; use client::{Client, EditPredictionUsage, UserStore}; -use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; +use cloud_llm_client::predict_edits_v3::{self, PromptFormat}; use cloud_llm_client::{ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME, }; -use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; use collections::{HashMap, HashSet}; use db::kvp::{Dismissable, KEY_VALUE_STORE}; use edit_prediction_context::EditPredictionExcerptOptions; @@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel use feature_flags::{FeatureFlag, FeatureFlagAppExt as _}; use futures::{ AsyncReadExt as _, FutureExt as _, StreamExt as _, - channel::{ - mpsc::{self, UnboundedReceiver}, - oneshot, - }, + channel::mpsc::{self, UnboundedReceiver}, select_biased, }; use gpui::BackgroundExecutor; @@ -58,8 +54,10 @@ mod onboarding_modal; pub mod open_ai_response; mod prediction; pub mod sweep_ai; + +#[cfg(any(test, feature = "test-support", feature = "eval-support"))] pub mod udiff; -mod xml_edits; + mod zed_edit_prediction_delegate; pub mod zeta1; pub mod zeta2; @@ -72,7 +70,6 @@ use crate::mercury::Mercury; use crate::onboarding_modal::ZedPredictModal; pub use crate::prediction::EditPrediction; pub use crate::prediction::EditPredictionId; -pub use crate::prediction::EditPredictionInputs; use crate::prediction::EditPredictionResult; pub use crate::sweep_ai::SweepAi; pub use telemetry_events::EditPredictionRating; @@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { min_bytes: 128, target_before_cursor_over_total_bytes: 0.5, }, - max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, prompt_format: PromptFormat::DEFAULT, }; @@ -162,7 +158,6 @@ pub struct EditPredictionStore { use_context: bool, options: ZetaOptions, update_required: bool, - debug_tx: Option>, #[cfg(feature = "eval-support")] eval_cache: Option>, edit_prediction_model: EditPredictionModel, @@ -183,10 +178,22 @@ pub enum EditPredictionModel { Mercury, } +pub struct EditPredictionModelInput { + project: Entity, + buffer: Entity, + snapshot: BufferSnapshot, + position: Anchor, + events: Vec>, + related_files: Arc<[RelatedFile]>, + recent_paths: VecDeque, + trigger: PredictEditsRequestTrigger, + diagnostic_search_range: Range, + debug_tx: Option>, +} + #[derive(Debug, Clone, PartialEq)] pub struct ZetaOptions { pub context: EditPredictionExcerptOptions, - pub max_prompt_bytes: usize, pub prompt_format: predict_edits_v3::PromptFormat, } @@ -194,7 +201,8 @@ pub struct ZetaOptions { pub enum DebugEvent { ContextRetrievalStarted(ContextRetrievalStartedDebugEvent), ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent), - EditPredictionRequested(EditPredictionRequestedDebugEvent), + EditPredictionStarted(EditPredictionStartedDebugEvent), + EditPredictionFinished(EditPredictionFinishedDebugEvent), } #[derive(Debug)] @@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent { } #[derive(Debug)] -pub struct EditPredictionRequestedDebugEvent { - pub inputs: EditPredictionInputs, - pub retrieval_time: Duration, +pub struct EditPredictionStartedDebugEvent { pub buffer: WeakEntity, pub position: Anchor, - pub local_prompt: Result, - pub response_rx: oneshot::Receiver<(Result, Duration)>, + pub prompt: Option, +} + +#[derive(Debug)] +pub struct EditPredictionFinishedDebugEvent { + pub buffer: WeakEntity, + pub position: Anchor, + pub model_output: Option, } pub type RequestDebugInfo = predict_edits_v3::DebugInfo; struct ProjectState { - events: VecDeque>, + events: VecDeque>, last_event: Option, recent_paths: VecDeque, registered_buffers: HashMap, current_prediction: Option, next_pending_prediction_id: usize, pending_predictions: ArrayVec, - context_updates_tx: smol::channel::Sender<()>, - context_updates_rx: smol::channel::Receiver<()>, + debug_tx: Option>, last_prediction_refresh: Option<(EntityId, Instant)>, cancelled_predictions: HashSet, context: Entity, @@ -241,7 +252,7 @@ struct ProjectState { } impl ProjectState { - pub fn events(&self, cx: &App) -> Vec> { + pub fn events(&self, cx: &App) -> Vec> { self.events .iter() .cloned() @@ -376,7 +387,7 @@ impl LastEvent { &self, license_detection_watchers: &HashMap>, cx: &App, - ) -> Option> { + ) -> Option> { let path = buffer_path_with_id_fallback(&self.new_snapshot, cx); let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx); @@ -396,7 +407,7 @@ impl LastEvent { if path == old_path && diff.is_empty() { None } else { - Some(Arc::new(predict_edits_v3::Event::BufferChange { + Some(Arc::new(zeta_prompt::Event::BufferChange { old_path, path, diff, @@ -481,7 +492,6 @@ impl EditPredictionStore { }, ), update_required: false, - debug_tx: None, #[cfg(feature = "eval-support")] eval_cache: None, edit_prediction_model: EditPredictionModel::Zeta2, @@ -536,12 +546,6 @@ impl EditPredictionStore { self.eval_cache = Some(cache); } - pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver { - let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); - self.debug_tx = Some(debug_watch_tx); - debug_watch_rx - } - pub fn options(&self) -> &ZetaOptions { &self.options } @@ -560,15 +564,35 @@ impl EditPredictionStore { } } + pub fn edit_history_for_project( + &self, + project: &Entity, + ) -> Vec> { + self.projects + .get(&project.entity_id()) + .map(|project_state| project_state.events.iter().cloned().collect()) + .unwrap_or_default() + } + pub fn context_for_project<'a>( &'a self, project: &Entity, cx: &'a App, - ) -> &'a [RelatedFile] { + ) -> Arc<[RelatedFile]> { self.projects .get(&project.entity_id()) .map(|project| project.context.read(cx).related_files()) - .unwrap_or(&[]) + .unwrap_or_else(|| vec![].into()) + } + + pub fn context_for_project_with_buffers<'a>( + &'a self, + project: &Entity, + cx: &'a App, + ) -> Option)>> { + self.projects + .get(&project.entity_id()) + .map(|project| project.context.read(cx).related_files_with_buffers()) } pub fn usage(&self, cx: &App) -> Option { @@ -599,85 +623,21 @@ impl EditPredictionStore { cx: &mut Context, ) -> &mut ProjectState { let entity_id = project.entity_id(); - let (context_updates_tx, context_updates_rx) = smol::channel::unbounded(); self.projects .entry(entity_id) .or_insert_with(|| ProjectState { context: { let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx)); - cx.subscribe( - &related_excerpt_store, - move |this, _, event, _| match event { - RelatedExcerptStoreEvent::StartedRefresh => { - if let Some(debug_tx) = this.debug_tx.clone() { - debug_tx - .unbounded_send(DebugEvent::ContextRetrievalStarted( - ContextRetrievalStartedDebugEvent { - project_entity_id: entity_id, - timestamp: Instant::now(), - search_prompt: String::new(), - }, - )) - .ok(); - } - } - RelatedExcerptStoreEvent::FinishedRefresh { - cache_hit_count, - cache_miss_count, - mean_definition_latency, - max_definition_latency, - } => { - if let Some(debug_tx) = this.debug_tx.clone() { - debug_tx - .unbounded_send(DebugEvent::ContextRetrievalFinished( - ContextRetrievalFinishedDebugEvent { - project_entity_id: entity_id, - timestamp: Instant::now(), - metadata: vec![ - ( - "Cache Hits", - format!( - "{}/{}", - cache_hit_count, - cache_hit_count + cache_miss_count - ) - .into(), - ), - ( - "Max LSP Time", - format!( - "{} ms", - max_definition_latency.as_millis() - ) - .into(), - ), - ( - "Mean LSP Time", - format!( - "{} ms", - mean_definition_latency.as_millis() - ) - .into(), - ), - ], - }, - )) - .ok(); - } - if let Some(project_state) = this.projects.get(&entity_id) { - project_state.context_updates_tx.send_blocking(()).ok(); - } - } - }, - ) + cx.subscribe(&related_excerpt_store, move |this, _, event, _| { + this.handle_excerpt_store_event(entity_id, event); + }) .detach(); related_excerpt_store }, events: VecDeque::new(), last_event: None, recent_paths: VecDeque::new(), - context_updates_rx, - context_updates_tx, + debug_tx: None, registered_buffers: HashMap::default(), current_prediction: None, cancelled_predictions: HashSet::default(), @@ -689,12 +649,79 @@ impl EditPredictionStore { }) } - pub fn project_context_updates( - &self, + pub fn remove_project(&mut self, project: &Entity) { + self.projects.remove(&project.entity_id()); + } + + fn handle_excerpt_store_event( + &mut self, + project_entity_id: EntityId, + event: &RelatedExcerptStoreEvent, + ) { + if let Some(project_state) = self.projects.get(&project_entity_id) { + if let Some(debug_tx) = project_state.debug_tx.clone() { + match event { + RelatedExcerptStoreEvent::StartedRefresh => { + debug_tx + .unbounded_send(DebugEvent::ContextRetrievalStarted( + ContextRetrievalStartedDebugEvent { + project_entity_id: project_entity_id, + timestamp: Instant::now(), + search_prompt: String::new(), + }, + )) + .ok(); + } + RelatedExcerptStoreEvent::FinishedRefresh { + cache_hit_count, + cache_miss_count, + mean_definition_latency, + max_definition_latency, + } => { + debug_tx + .unbounded_send(DebugEvent::ContextRetrievalFinished( + ContextRetrievalFinishedDebugEvent { + project_entity_id: project_entity_id, + timestamp: Instant::now(), + metadata: vec![ + ( + "Cache Hits", + format!( + "{}/{}", + cache_hit_count, + cache_hit_count + cache_miss_count + ) + .into(), + ), + ( + "Max LSP Time", + format!("{} ms", max_definition_latency.as_millis()) + .into(), + ), + ( + "Mean LSP Time", + format!("{} ms", mean_definition_latency.as_millis()) + .into(), + ), + ], + }, + )) + .ok(); + } + } + } + } + } + + pub fn debug_info( + &mut self, project: &Entity, - ) -> Option> { - let project_state = self.projects.get(&project.entity_id())?; - Some(project_state.context_updates_rx.clone()) + cx: &mut Context, + ) -> mpsc::UnboundedReceiver { + let project_state = self.get_or_init_project(project, cx); + let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded(); + project_state.debug_tx = Some(debug_watch_tx); + debug_watch_rx } fn handle_project_event( @@ -1348,6 +1375,7 @@ impl EditPredictionStore { let project_state = self.projects.get(&project.entity_id()).unwrap(); let events = project_state.events(cx); let has_events = !events.is_empty(); + let debug_tx = project_state.debug_tx.clone(); let snapshot = active_buffer.read(cx).snapshot(); let cursor_point = position.to_point(&snapshot); @@ -1357,55 +1385,29 @@ impl EditPredictionStore { Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0); let related_files = if self.use_context { - self.context_for_project(&project, cx).to_vec() + self.context_for_project(&project, cx) } else { - Vec::new() + Vec::new().into() + }; + + let inputs = EditPredictionModelInput { + project: project.clone(), + buffer: active_buffer.clone(), + snapshot: snapshot.clone(), + position, + events, + related_files, + recent_paths: project_state.recent_paths.clone(), + trigger, + diagnostic_search_range: diagnostic_search_range.clone(), + debug_tx, }; let task = match self.edit_prediction_model { - EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1( - self, - &project, - &active_buffer, - snapshot.clone(), - position, - events, - trigger, - cx, - ), - EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2( - self, - &project, - &active_buffer, - snapshot.clone(), - position, - events, - related_files, - trigger, - cx, - ), - EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep( - &project, - &active_buffer, - snapshot.clone(), - position, - events, - &project_state.recent_paths, - related_files, - diagnostic_search_range.clone(), - cx, - ), - EditPredictionModel::Mercury => self.mercury.request_prediction( - &project, - &active_buffer, - snapshot.clone(), - position, - events, - &project_state.recent_paths, - related_files, - diagnostic_search_range.clone(), - cx, - ), + EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx), + EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx), + EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx), + EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx), }; cx.spawn(async move |this, cx| { @@ -1706,6 +1708,20 @@ impl EditPredictionStore { } } + #[cfg(feature = "eval-support")] + pub fn set_context_for_buffer( + &mut self, + project: &Entity, + related_files: Vec, + cx: &mut Context, + ) { + self.get_or_init_project(project, cx) + .context + .update(cx, |store, _| { + store.set_related_files(related_files); + }); + } + fn is_file_open_source( &self, project: &Entity, @@ -1729,14 +1745,14 @@ impl EditPredictionStore { self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx) } - fn can_collect_events(&self, events: &[Arc]) -> bool { + fn can_collect_events(&self, events: &[Arc]) -> bool { if !self.data_collection_choice.is_enabled() { return false; } events.iter().all(|event| { matches!( event.as_ref(), - Event::BufferChange { + zeta_prompt::Event::BufferChange { in_open_source_repo: true, .. } diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 0b7e289bb32b5a10c32a4bd34f118d7cb6c7d43c..f6465b14cbd1b3357349071bc5eda399253b5328 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::zeta1::MAX_EVENT_TOKENS; +use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS}; use client::{UserStore, test::FakeServer}; use clock::{FakeSystemClock, ReplicaId}; use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; @@ -7,7 +7,6 @@ use cloud_llm_client::{ EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse, RejectEditPredictionsBody, }; -use edit_prediction_context::Line; use futures::{ AsyncReadExt, StreamExt, channel::{mpsc, oneshot}, @@ -28,6 +27,7 @@ use settings::SettingsStore; use std::{path::Path, sync::Arc, time::Duration}; use util::{path, rel_path::rel_path}; use uuid::Uuid; +use zeta_prompt::ZetaPromptInput; use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE}; @@ -65,18 +65,21 @@ async fn test_current_state(cx: &mut TestAppContext) { ep_store.update(cx, |ep_store, cx| { ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx) }); - let (_request, respond_tx) = requests.predict.next().await.unwrap(); + let (request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx - .send(model_response(indoc! {r" - --- a/root/1.txt - +++ b/root/1.txt - @@ ... @@ - Hello! - -How - +How are you? - Bye - "})) + .send(model_response( + request, + indoc! {r" + --- a/root/1.txt + +++ b/root/1.txt + @@ ... @@ + Hello! + -How + +How are you? + Bye + "}, + )) .unwrap(); cx.run_until_parked(); @@ -120,16 +123,20 @@ async fn test_current_state(cx: &mut TestAppContext) { }); }); - let (_request, respond_tx) = requests.predict.next().await.unwrap(); + let (request, respond_tx) = requests.predict.next().await.unwrap(); respond_tx - .send(model_response(indoc! {r#" - --- a/root/2.txt - +++ b/root/2.txt - Hola! - -Como - +Como estas? - Adios - "#})) + .send(model_response( + request, + indoc! {r#" + --- a/root/2.txt + +++ b/root/2.txt + @@ ... @@ + Hola! + -Como + +Como estas? + Adios + "#}, + )) .unwrap(); cx.run_until_parked(); @@ -186,7 +193,7 @@ async fn test_simple_request(cx: &mut TestAppContext) { ep_store.request_prediction(&project, &buffer, position, Default::default(), cx) }); - let (_, respond_tx) = requests.predict.next().await.unwrap(); + let (request, respond_tx) = requests.predict.next().await.unwrap(); // TODO Put back when we have a structured request again // assert_eq!( @@ -202,15 +209,18 @@ async fn test_simple_request(cx: &mut TestAppContext) { // ); respond_tx - .send(model_response(indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are you? - Bye - "})) + .send(model_response( + request, + indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "}, + )) .unwrap(); let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); @@ -276,15 +286,18 @@ async fn test_request_events(cx: &mut TestAppContext) { ); respond_tx - .send(model_response(indoc! {r#" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are you? - Bye - "#})) + .send(model_response( + request, + indoc! {r#" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are you? + Bye + "#}, + )) .unwrap(); let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap(); @@ -324,18 +337,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - const NO_OP_DIFF: &str = indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How - Bye - "}; - - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let response = model_response(NO_OP_DIFF); + let (request, respond_tx) = requests.predict.next().await.unwrap(); + let response = model_response(request, ""); let id = response.id.clone(); respond_tx.send(response).unwrap(); @@ -389,13 +392,13 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_tx) = requests.predict.next().await.unwrap(); + let (request, respond_tx) = requests.predict.next().await.unwrap(); buffer.update(cx, |buffer, cx| { buffer.set_text("Hello!\nHow are you?\nBye", cx); }); - let response = model_response(SIMPLE_DIFF); + let response = model_response(request, SIMPLE_DIFF); let id = response.id.clone(); respond_tx.send(response).unwrap(); @@ -459,8 +462,8 @@ async fn test_replace_current(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let first_response = model_response(SIMPLE_DIFF); + let (request, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(request, SIMPLE_DIFF); let first_id = first_response.id.clone(); respond_tx.send(first_response).unwrap(); @@ -482,8 +485,8 @@ async fn test_replace_current(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let second_response = model_response(SIMPLE_DIFF); + let (request, respond_tx) = requests.predict.next().await.unwrap(); + let second_response = model_response(request, SIMPLE_DIFF); let second_id = second_response.id.clone(); respond_tx.send(second_response).unwrap(); @@ -541,8 +544,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_tx) = requests.predict.next().await.unwrap(); - let first_response = model_response(SIMPLE_DIFF); + let (request, respond_tx) = requests.predict.next().await.unwrap(); + let first_response = model_response(request, SIMPLE_DIFF); let first_id = first_response.id.clone(); respond_tx.send(first_response).unwrap(); @@ -564,17 +567,20 @@ async fn test_current_preferred(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_tx) = requests.predict.next().await.unwrap(); + let (request, respond_tx) = requests.predict.next().await.unwrap(); // worse than current prediction - let second_response = model_response(indoc! { r" - --- a/root/foo.md - +++ b/root/foo.md - @@ ... @@ - Hello! - -How - +How are - Bye - "}); + let second_response = model_response( + request, + indoc! { r" + --- a/root/foo.md + +++ b/root/foo.md + @@ ... @@ + Hello! + -How + +How are + Bye + "}, + ); let second_id = second_response.id.clone(); respond_tx.send(second_response).unwrap(); @@ -633,19 +639,19 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_first) = requests.predict.next().await.unwrap(); + let (request1, respond_first) = requests.predict.next().await.unwrap(); ep_store.update(cx, |ep_store, cx| { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_second) = requests.predict.next().await.unwrap(); + let (request, respond_second) = requests.predict.next().await.unwrap(); // wait for throttle cx.run_until_parked(); // second responds first - let second_response = model_response(SIMPLE_DIFF); + let second_response = model_response(request, SIMPLE_DIFF); let second_id = second_response.id.clone(); respond_second.send(second_response).unwrap(); @@ -663,7 +669,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) { ); }); - let first_response = model_response(SIMPLE_DIFF); + let first_response = model_response(request1, SIMPLE_DIFF); let first_id = first_response.id.clone(); respond_first.send(first_response).unwrap(); @@ -724,13 +730,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_first) = requests.predict.next().await.unwrap(); + let (request1, respond_first) = requests.predict.next().await.unwrap(); ep_store.update(cx, |ep_store, cx| { ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx); }); - let (_, respond_second) = requests.predict.next().await.unwrap(); + let (request2, respond_second) = requests.predict.next().await.unwrap(); // wait for throttle, so requests are sent cx.run_until_parked(); @@ -754,9 +760,9 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { // wait for throttle cx.run_until_parked(); - let (_, respond_third) = requests.predict.next().await.unwrap(); + let (request3, respond_third) = requests.predict.next().await.unwrap(); - let first_response = model_response(SIMPLE_DIFF); + let first_response = model_response(request1, SIMPLE_DIFF); let first_id = first_response.id.clone(); respond_first.send(first_response).unwrap(); @@ -774,7 +780,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { ); }); - let cancelled_response = model_response(SIMPLE_DIFF); + let cancelled_response = model_response(request2, SIMPLE_DIFF); let cancelled_id = cancelled_response.id.clone(); respond_second.send(cancelled_response).unwrap(); @@ -792,7 +798,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) { ); }); - let third_response = model_response(SIMPLE_DIFF); + let third_response = model_response(request3, SIMPLE_DIFF); let third_response_id = third_response.id.clone(); respond_third.send(third_response).unwrap(); @@ -1036,7 +1042,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) { // ); // } -fn model_response(text: &str) -> open_ai::Response { +// Generate a model response that would apply the given diff to the active file. +fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response { + let prompt = match &request.messages[0] { + open_ai::RequestMessage::User { + content: open_ai::MessageContent::Plain(content), + } => content, + _ => panic!("unexpected request {request:?}"), + }; + + let open = "\n"; + let close = ""; + let cursor = "<|user_cursor|>"; + + let start_ix = open.len() + prompt.find(open).unwrap(); + let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap(); + let excerpt = prompt[start_ix..end_ix].replace(cursor, ""); + let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap(); + open_ai::Response { id: Uuid::new_v4().to_string(), object: "response".into(), @@ -1045,7 +1068,7 @@ fn model_response(text: &str) -> open_ai::Response { choices: vec![open_ai::Choice { index: 0, message: open_ai::RequestMessage::Assistant { - content: Some(open_ai::MessageContent::Plain(text.to_string())), + content: Some(open_ai::MessageContent::Plain(new_excerpt)), tool_calls: vec![], }, finish_reason: None, @@ -1160,20 +1183,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx)) .await; - let completion = EditPrediction { + let prediction = EditPrediction { edits, edit_preview, buffer: buffer.clone(), snapshot: cx.read(|cx| buffer.read(cx).snapshot()), id: EditPredictionId("the-id".into()), - inputs: EditPredictionInputs { + inputs: ZetaPromptInput { events: Default::default(), - included_files: Default::default(), - cursor_point: cloud_llm_client::predict_edits_v3::Point { - line: Line(0), - column: 0, - }, + related_files: Default::default(), cursor_path: Path::new("").into(), + cursor_excerpt: "".into(), + editable_range_in_excerpt: 0..0, + cursor_offset_in_excerpt: 0, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), @@ -1182,7 +1204,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { cx.update(|cx| { assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1192,7 +1214,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx)); assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1202,7 +1224,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { buffer.update(cx, |buffer, cx| buffer.undo(cx)); assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1212,7 +1234,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx)); assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1222,7 +1244,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx)); assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1232,7 +1254,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx)); assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1242,7 +1264,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx)); assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1252,7 +1274,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx)); assert_eq!( from_completion_edits( - &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(), + &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(), &buffer, cx ), @@ -1260,7 +1282,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { ); buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx)); - assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None); + assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None); }) } diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 40c0fdfac021f937df5172fd423d3b6bfc5f8146..f3a3afc53fc5e175fdbda2dc6b5867da6fd38feb 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -1,20 +1,17 @@ use anyhow::{Context as _, Result}; -use cloud_llm_client::predict_edits_v3::Event; use credentials_provider::CredentialsProvider; -use edit_prediction_context::RelatedFile; use futures::{AsyncReadExt as _, FutureExt, future::Shared}; use gpui::{ - App, AppContext as _, Entity, Task, + App, AppContext as _, Task, http_client::{self, AsyncBody, Method}, }; -use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _}; -use project::{Project, ProjectPath}; -use std::{ - collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant, -}; +use language::{OffsetRangeExt as _, ToOffset, ToPoint as _}; +use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant}; +use zeta_prompt::ZetaPromptInput; use crate::{ - EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response, + DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, + EditPredictionStartedDebugEvent, open_ai_response::text_from_response, prediction::EditPredictionResult, }; @@ -38,16 +35,17 @@ impl Mercury { store_api_token_in_keychain(api_token, cx) } - pub fn request_prediction( + pub(crate) fn request_prediction( &self, - _project: &Entity, - active_buffer: &Entity, - snapshot: BufferSnapshot, - position: language::Anchor, - events: Vec>, - _recent_paths: &VecDeque, - related_files: Vec, - _diagnostic_search_range: Range, + EditPredictionModelInput { + buffer, + snapshot, + position, + events, + related_files, + debug_tx, + .. + }: EditPredictionModelInput, cx: &mut App, ) -> Task>> { let Some(api_token) = self.api_token.clone().now_or_never().flatten() else { @@ -62,6 +60,7 @@ impl Mercury { let http_client = cx.http_client(); let cursor_point = position.to_point(&snapshot); let buffer_snapshotted_at = Instant::now(); + let active_buffer = buffer.clone(); let result = cx.background_spawn(async move { let (editable_range, context_range) = @@ -72,39 +71,39 @@ impl Mercury { MAX_REWRITE_TOKENS, ); - let offset_range = editable_range.to_offset(&snapshot); - let prompt = build_prompt( - &events, - &related_files, - &snapshot, - full_path.as_ref(), - cursor_point, - editable_range, - context_range.clone(), - ); - - let inputs = EditPredictionInputs { - events: events, - included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { - path: full_path.clone(), - max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), - excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { - start_line: cloud_llm_client::predict_edits_v3::Line( - context_range.start.row, - ), - text: snapshot - .text_for_range(context_range.clone()) - .collect::() - .into(), - }], - }], - cursor_point: cloud_llm_client::predict_edits_v3::Point { - column: cursor_point.column, - line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), - }, + let context_offset_range = context_range.to_offset(&snapshot); + + let editable_offset_range = editable_range.to_offset(&snapshot); + + let inputs = zeta_prompt::ZetaPromptInput { + events, + related_files, + cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) + - context_range.start.to_offset(&snapshot), cursor_path: full_path.clone(), + cursor_excerpt: snapshot + .text_for_range(context_range) + .collect::() + .into(), + editable_range_in_excerpt: (editable_offset_range.start + - context_offset_range.start) + ..(editable_offset_range.end - context_offset_range.start), }; + let prompt = build_prompt(&inputs); + + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(DebugEvent::EditPredictionStarted( + EditPredictionStartedDebugEvent { + buffer: active_buffer.downgrade(), + prompt: Some(prompt.clone()), + position, + }, + )) + .ok(); + } + let request_body = open_ai::Request { model: "mercury-coder".into(), messages: vec![open_ai::RequestMessage::User { @@ -160,6 +159,18 @@ impl Mercury { let id = mem::take(&mut response.id); let response_str = text_from_response(response).unwrap_or_default(); + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(DebugEvent::EditPredictionFinished( + EditPredictionFinishedDebugEvent { + buffer: active_buffer.downgrade(), + model_output: Some(response_str.clone()), + position, + }, + )) + .ok(); + } + let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str); let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str); @@ -168,15 +179,16 @@ impl Mercury { if response_str != NO_PREDICTION_OUTPUT { let old_text = snapshot - .text_for_range(offset_range.clone()) + .text_for_range(editable_offset_range.clone()) .collect::(); edits.extend( language::text_diff(&old_text, &response_str) .into_iter() .map(|(range, text)| { ( - snapshot.anchor_after(offset_range.start + range.start) - ..snapshot.anchor_before(offset_range.start + range.end), + snapshot.anchor_after(editable_offset_range.start + range.start) + ..snapshot + .anchor_before(editable_offset_range.start + range.end), text, ) }), @@ -186,8 +198,6 @@ impl Mercury { anyhow::Ok((id, edits, snapshot, response_received_at, inputs)) }); - let buffer = active_buffer.clone(); - cx.spawn(async move |cx| { let (id, edits, old_snapshot, response_received_at, inputs) = result.await.context("Mercury edit prediction failed")?; @@ -208,15 +218,7 @@ impl Mercury { } } -fn build_prompt( - events: &[Arc], - related_files: &[RelatedFile], - cursor_buffer: &BufferSnapshot, - cursor_buffer_path: &Path, - cursor_point: Point, - editable_range: Range, - context_range: Range, -) -> String { +fn build_prompt(inputs: &ZetaPromptInput) -> String { const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n"; const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n"; const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n"; @@ -237,14 +239,14 @@ fn build_prompt( &mut prompt, RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END, |prompt| { - for related_file in related_files { + for related_file in inputs.related_files.iter() { for related_excerpt in &related_file.excerpts { push_delimited( prompt, RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END, |prompt| { prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX); - prompt.push_str(related_file.path.path.as_unix_str()); + prompt.push_str(related_file.path.to_string_lossy().as_ref()); prompt.push('\n'); prompt.push_str(&related_excerpt.text.to_string()); }, @@ -259,21 +261,22 @@ fn build_prompt( CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END, |prompt| { prompt.push_str(CURRENT_FILE_PATH_PREFIX); - prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref()); + prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref()); prompt.push('\n'); - let prefix_range = context_range.start..editable_range.start; - let suffix_range = editable_range.end..context_range.end; - - prompt.extend(cursor_buffer.text_for_range(prefix_range)); + prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]); push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| { - let range_before_cursor = editable_range.start..cursor_point; - let range_after_cursor = cursor_point..editable_range.end; - prompt.extend(cursor_buffer.text_for_range(range_before_cursor)); + prompt.push_str( + &inputs.cursor_excerpt + [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt], + ); prompt.push_str(CURSOR_TAG); - prompt.extend(cursor_buffer.text_for_range(range_after_cursor)); + prompt.push_str( + &inputs.cursor_excerpt + [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end], + ); }); - prompt.extend(cursor_buffer.text_for_range(suffix_range)); + prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]); }, ); @@ -281,8 +284,8 @@ fn build_prompt( &mut prompt, EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END, |prompt| { - for event in events { - writeln!(prompt, "{event}").unwrap(); + for event in inputs.events.iter() { + zeta_prompt::write_event(prompt, &event); } }, ); diff --git a/crates/edit_prediction/src/prediction.rs b/crates/edit_prediction/src/prediction.rs index 8aa2a8218568a99404cc9aceff36b84127700152..c63640ccd0e1815b32f736e8a0fee8d75d124df1 100644 --- a/crates/edit_prediction/src/prediction.rs +++ b/crates/edit_prediction/src/prediction.rs @@ -1,6 +1,5 @@ use std::{ ops::Range, - path::Path, sync::Arc, time::{Duration, Instant}, }; @@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason; use edit_prediction_types::interpolate_edits; use gpui::{AsyncApp, Entity, SharedString}; use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot}; -use serde::Serialize; +use zeta_prompt::ZetaPromptInput; #[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] pub struct EditPredictionId(pub SharedString); @@ -40,7 +39,7 @@ impl EditPredictionResult { edits: Arc<[(Range, Arc)]>, buffer_snapshotted_at: Instant, response_received_at: Instant, - inputs: EditPredictionInputs, + inputs: ZetaPromptInput, cx: &mut AsyncApp, ) -> Self { if edits.is_empty() { @@ -94,15 +93,7 @@ pub struct EditPrediction { pub buffer: Entity, pub buffer_snapshotted_at: Instant, pub response_received_at: Instant, - pub inputs: EditPredictionInputs, -} - -#[derive(Debug, Clone, Serialize)] -pub struct EditPredictionInputs { - pub events: Vec>, - pub included_files: Vec, - pub cursor_point: cloud_llm_client::predict_edits_v3::Point, - pub cursor_path: Arc, + pub inputs: zeta_prompt::ZetaPromptInput, } impl EditPrediction { @@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction { #[cfg(test)] mod tests { + use std::path::Path; + use super::*; use gpui::{App, Entity, TestAppContext, prelude::*}; use language::{Buffer, ToOffset as _}; + use zeta_prompt::ZetaPromptInput; #[gpui::test] async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) { @@ -154,14 +148,13 @@ mod tests { snapshot: cx.read(|cx| buffer.read(cx).snapshot()), buffer: buffer.clone(), edit_preview, - inputs: EditPredictionInputs { + inputs: ZetaPromptInput { events: vec![], - included_files: vec![], - cursor_point: cloud_llm_client::predict_edits_v3::Point { - line: cloud_llm_client::predict_edits_v3::Line(0), - column: 0, - }, + related_files: vec![].into(), cursor_path: Path::new("path.txt").into(), + cursor_offset_in_excerpt: 0, + cursor_excerpt: "".into(), + editable_range_in_excerpt: 0..0, }, buffer_snapshotted_at: Instant::now(), response_received_at: Instant::now(), diff --git a/crates/edit_prediction/src/sweep_ai.rs b/crates/edit_prediction/src/sweep_ai.rs index 4bb014c640cb489db29c800835a58febf91a7270..f65749ceadf6e05fc3b56838c03234b2f83dc51e 100644 --- a/crates/edit_prediction/src/sweep_ai.rs +++ b/crates/edit_prediction/src/sweep_ai.rs @@ -1,26 +1,21 @@ use anyhow::{Context as _, Result}; -use cloud_llm_client::predict_edits_v3::Event; use credentials_provider::CredentialsProvider; -use edit_prediction_context::RelatedFile; use futures::{AsyncReadExt as _, FutureExt, future::Shared}; use gpui::{ - App, AppContext as _, Entity, Task, + App, AppContext as _, Task, http_client::{self, AsyncBody, Method}, }; -use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _}; +use language::{Point, ToOffset as _}; use lsp::DiagnosticSeverity; -use project::{Project, ProjectPath}; use serde::{Deserialize, Serialize}; use std::{ - collections::VecDeque, fmt::{self, Write as _}, - ops::Range, path::Path, sync::Arc, time::Instant, }; -use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult}; +use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult}; const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete"; @@ -44,40 +39,34 @@ impl SweepAi { pub fn request_prediction_with_sweep( &self, - project: &Entity, - active_buffer: &Entity, - snapshot: BufferSnapshot, - position: language::Anchor, - events: Vec>, - recent_paths: &VecDeque, - related_files: Vec, - diagnostic_search_range: Range, + inputs: EditPredictionModelInput, cx: &mut App, ) -> Task>> { let debug_info = self.debug_info.clone(); let Some(api_token) = self.api_token.clone().now_or_never().flatten() else { return Task::ready(Ok(None)); }; - let full_path: Arc = snapshot + let full_path: Arc = inputs + .snapshot .file() .map(|file| file.full_path(cx)) .unwrap_or_else(|| "untitled".into()) .into(); - let project_file = project::File::from_dyn(snapshot.file()); + let project_file = project::File::from_dyn(inputs.snapshot.file()); let repo_name = project_file .map(|file| file.worktree.read(cx).root_name_str()) .unwrap_or("untitled") .into(); - let offset = position.to_offset(&snapshot); + let offset = inputs.position.to_offset(&inputs.snapshot); - let recent_buffers = recent_paths.iter().cloned(); + let recent_buffers = inputs.recent_paths.iter().cloned(); let http_client = cx.http_client(); let recent_buffer_snapshots = recent_buffers .filter_map(|project_path| { - let buffer = project.read(cx).get_open_buffer(&project_path, cx)?; - if active_buffer == &buffer { + let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?; + if inputs.buffer == buffer { None } else { Some(buffer.read(cx).snapshot()) @@ -86,14 +75,13 @@ impl SweepAi { .take(3) .collect::>(); - let cursor_point = position.to_point(&snapshot); let buffer_snapshotted_at = Instant::now(); let result = cx.background_spawn(async move { - let text = snapshot.text(); + let text = inputs.snapshot.text(); let mut recent_changes = String::new(); - for event in &events { + for event in &inputs.events { write_event(event.as_ref(), &mut recent_changes).unwrap(); } @@ -122,20 +110,23 @@ impl SweepAi { }) .collect::>(); - let retrieval_chunks = related_files + let retrieval_chunks = inputs + .related_files .iter() .flat_map(|related_file| { related_file.excerpts.iter().map(|excerpt| FileChunk { - file_path: related_file.path.path.as_unix_str().to_string(), - start_line: excerpt.point_range.start.row as usize, - end_line: excerpt.point_range.end.row as usize, + file_path: related_file.path.to_string_lossy().to_string(), + start_line: excerpt.row_range.start as usize, + end_line: excerpt.row_range.end as usize, content: excerpt.text.to_string(), timestamp: None, }) }) .collect(); - let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false); + let diagnostic_entries = inputs + .snapshot + .diagnostics_in_range(inputs.diagnostic_search_range, false); let mut diagnostic_content = String::new(); let mut diagnostic_count = 0; @@ -195,21 +186,14 @@ impl SweepAi { serde_json::to_writer(writer, &request_body)?; let body: AsyncBody = buf.into(); - let inputs = EditPredictionInputs { - events, - included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { - path: full_path.clone(), - max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), - excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { - start_line: cloud_llm_client::predict_edits_v3::Line(0), - text: request_body.file_contents.into(), - }], - }], - cursor_point: cloud_llm_client::predict_edits_v3::Point { - column: cursor_point.column, - line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), - }, + let ep_inputs = zeta_prompt::ZetaPromptInput { + events: inputs.events, + related_files: inputs.related_files.clone(), cursor_path: full_path.clone(), + cursor_excerpt: request_body.file_contents.into(), + // we actually don't know + editable_range_in_excerpt: 0..inputs.snapshot.len(), + cursor_offset_in_excerpt: request_body.cursor_position, }; let request = http_client::Request::builder() @@ -237,15 +221,20 @@ impl SweepAi { let response: AutocompleteResponse = serde_json::from_slice(&body)?; - let old_text = snapshot + let old_text = inputs + .snapshot .text_for_range(response.start_index..response.end_index) .collect::(); let edits = language::text_diff(&old_text, &response.completion) .into_iter() .map(|(range, text)| { ( - snapshot.anchor_after(response.start_index + range.start) - ..snapshot.anchor_before(response.start_index + range.end), + inputs + .snapshot + .anchor_after(response.start_index + range.start) + ..inputs + .snapshot + .anchor_before(response.start_index + range.end), text, ) }) @@ -254,13 +243,13 @@ impl SweepAi { anyhow::Ok(( response.autocomplete_id, edits, - snapshot, + inputs.snapshot, response_received_at, - inputs, + ep_inputs, )) }); - let buffer = active_buffer.clone(); + let buffer = inputs.buffer.clone(); cx.spawn(async move |cx| { let (id, edits, old_snapshot, response_received_at, inputs) = result.await?; @@ -403,12 +392,9 @@ struct AdditionalCompletion { pub finish_reason: Option, } -fn write_event( - event: &cloud_llm_client::predict_edits_v3::Event, - f: &mut impl fmt::Write, -) -> fmt::Result { +fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result { match event { - cloud_llm_client::predict_edits_v3::Event::BufferChange { + zeta_prompt::Event::BufferChange { old_path, path, diff, diff --git a/crates/edit_prediction/src/udiff.rs b/crates/edit_prediction/src/udiff.rs index 5ae029c6c16c2c6b6d0c2451cc961e8399a64a8f..b9cf564c16d68a98baa1986333f2bfd767c6a24b 100644 --- a/crates/edit_prediction/src/udiff.rs +++ b/crates/edit_prediction/src/udiff.rs @@ -14,68 +14,18 @@ use anyhow::anyhow; use collections::HashMap; use gpui::AsyncApp; use gpui::Entity; -use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot}; +use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot}; use project::Project; -pub async fn parse_diff<'a>( - diff_str: &'a str, - get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send, -) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> { - let mut diff = DiffParser::new(diff_str); - let mut edited_buffer = None; - let mut edits = Vec::new(); - - while let Some(event) = diff.next()? { - match event { - DiffEvent::Hunk { - path: file_path, - hunk, - } => { - let (buffer, ranges) = match edited_buffer { - None => { - edited_buffer = get_buffer(&Path::new(file_path.as_ref())); - edited_buffer - .as_ref() - .context("Model tried to edit a file that wasn't included")? - } - Some(ref current) => current, - }; - - edits.extend( - resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges) - .with_context(|| format!("Diff:\n{diff_str}"))?, - ); - } - DiffEvent::FileEnd { renamed_to } => { - let (buffer, _) = edited_buffer - .take() - .context("Got a FileEnd event before an Hunk event")?; - - if renamed_to.is_some() { - anyhow::bail!("edit predictions cannot rename files"); - } - - if diff.next()?.is_some() { - anyhow::bail!("Edited more than one file"); - } - - return Ok((buffer, edits)); - } - } - } - - Err(anyhow::anyhow!("No EOF")) -} - -#[derive(Debug)] -pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap, Entity>); +#[derive(Clone, Debug)] +pub struct OpenedBuffers(#[allow(unused)] HashMap>); #[must_use] -pub async fn apply_diff<'a>( - diff_str: &'a str, +pub async fn apply_diff( + diff_str: &str, project: &Entity, cx: &mut AsyncApp, -) -> Result> { +) -> Result { let mut included_files = HashMap::default(); for line in diff_str.lines() { @@ -94,7 +44,7 @@ pub async fn apply_diff<'a>( })?? .await?; - included_files.insert(path, buffer); + included_files.insert(path.to_string(), buffer); } } @@ -113,7 +63,7 @@ pub async fn apply_diff<'a>( let (buffer, ranges) = match current_file { None => { let buffer = included_files - .get_mut(&file_path) + .get_mut(file_path.as_ref()) .expect("Opened all files in diff"); current_file = Some((buffer, ranges.as_slice())); @@ -167,6 +117,29 @@ pub async fn apply_diff<'a>( Ok(OpenedBuffers(included_files)) } +pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result { + let mut diff = DiffParser::new(diff_str); + + let mut text = text.to_string(); + + while let Some(event) = diff.next()? { + match event { + DiffEvent::Hunk { hunk, .. } => { + let hunk_offset = text + .find(&hunk.context) + .ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?; + for edit in hunk.edits.iter().rev() { + let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end); + text.replace_range(range, &edit.text); + } + } + DiffEvent::FileEnd { .. } => {} + } + } + + Ok(text) +} + struct PatchFile<'a> { old_path: Cow<'a, str>, new_path: Cow<'a, str>, @@ -492,7 +465,6 @@ mod tests { use super::*; use gpui::TestAppContext; use indoc::indoc; - use language::Point; use pretty_assertions::assert_eq; use project::{FakeFs, Project}; use serde_json::json; @@ -817,137 +789,6 @@ mod tests { }); } - #[gpui::test] - async fn test_apply_diff_non_unique(cx: &mut TestAppContext) { - let fs = init_test(cx); - - let buffer_1_text = indoc! {r#" - one - two - three - four - five - one - two - three - four - five - "# }; - - fs.insert_tree( - path!("/root"), - json!({ - "file1": buffer_1_text, - }), - ) - .await; - - let project = Project::test(fs, [path!("/root").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/root/file1"), cx) - }) - .await - .unwrap(); - let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - - let diff = indoc! {r#" - --- a/root/file1 - +++ b/root/file1 - one - two - -three - +3 - four - five - "#}; - - let final_text = indoc! {r#" - one - two - three - four - five - one - two - 3 - four - five - "#}; - - apply_diff(diff, &project, &mut cx.to_async()) - .await - .expect_err("Non-unique edits should fail"); - - let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0)) - ..buffer_snapshot.anchor_after(buffer_snapshot.max_point())]; - - let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges))) - .await - .unwrap(); - - assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id()); - buffer.update(cx, |buffer, cx| { - buffer.edit(edits, None, cx); - assert_eq!(buffer.text(), final_text); - }); - } - - #[gpui::test] - async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) { - let fs = init_test(cx); - - let buffer_1_text = indoc! {r#" - one two three four - five six seven eight - nine ten eleven twelve - "# }; - - fs.insert_tree( - path!("/root"), - json!({ - "file1": buffer_1_text, - }), - ) - .await; - - let project = Project::test(fs, [path!("/root").as_ref()], cx).await; - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/root/file1"), cx) - }) - .await - .unwrap(); - let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - - let diff = indoc! {r#" - --- a/root/file1 - +++ b/root/file1 - one two three four - -five six seven eight - +five SIX seven eight! - nine ten eleven twelve - "#}; - - let (buffer, edits) = parse_diff(diff, |_path| { - Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_])) - }) - .await - .unwrap(); - - let edits = edits - .into_iter() - .map(|(range, text)| (range.to_point(&buffer), text)) - .collect::>(); - assert_eq!( - edits, - &[ - (Point::new(1, 5)..Point::new(1, 8), "SIX".into()), - (Point::new(1, 20)..Point::new(1, 20), "!".into()) - ] - ); - } - #[gpui::test] async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) { let fs = init_test(cx); diff --git a/crates/edit_prediction/src/xml_edits.rs b/crates/edit_prediction/src/xml_edits.rs deleted file mode 100644 index ee8dd47cb25ad3dcd2c3d7d172b62e724b41c22d..0000000000000000000000000000000000000000 --- a/crates/edit_prediction/src/xml_edits.rs +++ /dev/null @@ -1,637 +0,0 @@ -use anyhow::{Context as _, Result}; -use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point}; -use std::{cmp, ops::Range, path::Path, sync::Arc}; - -const EDITS_TAG_NAME: &'static str = "edits"; -const OLD_TEXT_TAG_NAME: &'static str = "old_text"; -const NEW_TEXT_TAG_NAME: &'static str = "new_text"; -const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME]; - -pub async fn parse_xml_edits<'a>( - input: &'a str, - get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send, -) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> { - parse_xml_edits_inner(input, get_buffer) - .await - .with_context(|| format!("Failed to parse XML edits:\n{input}")) -} - -async fn parse_xml_edits_inner<'a>( - input: &'a str, - get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range])> + Send, -) -> Result<(&'a BufferSnapshot, Vec<(Range, Arc)>)> { - let xml_edits = extract_xml_replacements(input)?; - - let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref()) - .with_context(|| format!("no buffer for file {}", xml_edits.file_path))?; - - let mut all_edits = vec![]; - for (old_text, new_text) in xml_edits.replacements { - let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?; - let matched_old_text = buffer - .text_for_range(match_range.clone()) - .collect::(); - let edits_within_hunk = language::text_diff(&matched_old_text, new_text); - all_edits.extend( - edits_within_hunk - .into_iter() - .map(move |(inner_range, inner_text)| { - ( - buffer.anchor_after(match_range.start + inner_range.start) - ..buffer.anchor_before(match_range.start + inner_range.end), - inner_text, - ) - }), - ); - } - - Ok((buffer, all_edits)) -} - -fn fuzzy_match_in_ranges( - old_text: &str, - buffer: &BufferSnapshot, - context_ranges: &[Range], -) -> Result> { - let mut state = FuzzyMatcher::new(buffer, old_text); - let mut best_match = None; - let mut tie_match_range = None; - - for range in context_ranges { - let best_match_cost = best_match.as_ref().map(|(score, _)| *score); - match (best_match_cost, state.match_range(range.to_offset(buffer))) { - (Some(lowest_cost), Some((new_cost, new_range))) => { - if new_cost == lowest_cost { - tie_match_range = Some(new_range); - } else if new_cost < lowest_cost { - tie_match_range.take(); - best_match = Some((new_cost, new_range)); - } - } - (None, Some(new_match)) => { - best_match = Some(new_match); - } - (None, None) | (Some(_), None) => {} - }; - } - - if let Some((_, best_match_range)) = best_match { - if let Some(tie_match_range) = tie_match_range { - anyhow::bail!( - "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}", - best_match_range.clone(), - buffer.text_for_range(best_match_range).collect::(), - tie_match_range.clone(), - buffer.text_for_range(tie_match_range).collect::() - ); - } - return Ok(best_match_range); - } - - anyhow::bail!( - "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```", - old_text, - context_ranges - .iter() - .map(|range| buffer.text_for_range(range.clone()).collect::()) - .collect::>() - .join("```\n```") - ); -} - -#[derive(Debug)] -struct XmlEdits<'a> { - file_path: &'a str, - /// Vec of (old_text, new_text) pairs - replacements: Vec<(&'a str, &'a str)>, -} - -fn extract_xml_replacements(input: &str) -> Result> { - let mut cursor = 0; - - let (edits_body_start, edits_attrs) = - find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?; - - let file_path = edits_attrs - .trim_start() - .strip_prefix("path") - .context("no path attribute on edits tag")? - .trim_end() - .strip_prefix('=') - .context("no value for path attribute")? - .trim() - .trim_start_matches('"') - .trim_end_matches('"'); - - cursor = edits_body_start; - let mut edits_list = Vec::new(); - - while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? { - let old_body_end = find_tag_close(input, &mut cursor)?; - let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]); - - let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)? - .context("no new_text tag following old_text")?; - let new_body_end = find_tag_close(input, &mut cursor)?; - let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]); - - edits_list.push((old_text, new_text)); - } - - Ok(XmlEdits { - file_path, - replacements: edits_list, - }) -} - -/// Trims a single leading and trailing newline -fn trim_surrounding_newlines(input: &str) -> &str { - let start = input.strip_prefix('\n').unwrap_or(input); - let end = start.strip_suffix('\n').unwrap_or(start); - end -} - -fn find_tag_open<'a>( - input: &'a str, - cursor: &mut usize, - expected_tag: &str, -) -> Result> { - let mut search_pos = *cursor; - - while search_pos < input.len() { - let Some(tag_start) = input[search_pos..].find("<") else { - break; - }; - let tag_start = search_pos + tag_start; - if !input[tag_start + 1..].starts_with(expected_tag) { - search_pos = search_pos + tag_start + 1; - continue; - }; - - let after_tag_name = tag_start + expected_tag.len() + 1; - let close_bracket = input[after_tag_name..] - .find('>') - .with_context(|| format!("missing > after <{}", expected_tag))?; - let attrs_end = after_tag_name + close_bracket; - let body_start = attrs_end + 1; - - let attributes = input[after_tag_name..attrs_end].trim(); - *cursor = body_start; - - return Ok(Some((body_start, attributes))); - } - - Ok(None) -} - -fn find_tag_close(input: &str, cursor: &mut usize) -> Result { - let mut depth = 1; - let mut search_pos = *cursor; - - while search_pos < input.len() && depth > 0 { - let Some(bracket_offset) = input[search_pos..].find('<') else { - break; - }; - let bracket_pos = search_pos + bracket_offset; - - if input[bracket_pos..].starts_with("') - { - let close_start = bracket_pos + 2; - let tag_name = input[close_start..close_start + close_end].trim(); - - if XML_TAGS.contains(&tag_name) { - depth -= 1; - if depth == 0 { - *cursor = close_start + close_end + 1; - return Ok(bracket_pos); - } - } - search_pos = close_start + close_end + 1; - continue; - } else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') { - let close_bracket_pos = bracket_pos + close_bracket_offset; - let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim(); - if XML_TAGS.contains(&tag_name) { - depth += 1; - } - } - - search_pos = bracket_pos + 1; - } - - anyhow::bail!("no closing tag found") -} - -const REPLACEMENT_COST: u32 = 1; -const INSERTION_COST: u32 = 3; -const DELETION_COST: u32 = 10; - -/// A fuzzy matcher that can process text chunks incrementally -/// and return the best match found so far at each step. -struct FuzzyMatcher<'a> { - snapshot: &'a BufferSnapshot, - query_lines: Vec<&'a str>, - matrix: SearchMatrix, -} - -impl<'a> FuzzyMatcher<'a> { - fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self { - let query_lines = old_text.lines().collect(); - Self { - snapshot, - query_lines, - matrix: SearchMatrix::new(0), - } - } - - fn match_range(&mut self, range: Range) -> Option<(u32, Range)> { - let point_range = range.to_point(&self.snapshot); - let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize; - - self.matrix - .reset(self.query_lines.len() + 1, buffer_line_count + 1); - let query_line_count = self.query_lines.len(); - - for row in 0..query_line_count { - let query_line = self.query_lines[row].trim(); - let leading_deletion_cost = (row + 1) as u32 * DELETION_COST; - - self.matrix.set( - row + 1, - 0, - SearchState::new(leading_deletion_cost, SearchDirection::Up), - ); - - let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines(); - - let mut col = 0; - while let Some(buffer_line) = buffer_lines.next() { - let buffer_line = buffer_line.trim(); - let up = SearchState::new( - self.matrix - .get(row, col + 1) - .cost - .saturating_add(DELETION_COST), - SearchDirection::Up, - ); - let left = SearchState::new( - self.matrix - .get(row + 1, col) - .cost - .saturating_add(INSERTION_COST), - SearchDirection::Left, - ); - let diagonal = SearchState::new( - if query_line == buffer_line { - self.matrix.get(row, col).cost - } else if fuzzy_eq(query_line, buffer_line) { - self.matrix.get(row, col).cost + REPLACEMENT_COST - } else { - self.matrix - .get(row, col) - .cost - .saturating_add(DELETION_COST + INSERTION_COST) - }, - SearchDirection::Diagonal, - ); - self.matrix - .set(row + 1, col + 1, up.min(left).min(diagonal)); - col += 1; - } - } - - // Find all matches with the best cost - let mut best_cost = u32::MAX; - let mut matches_with_best_cost = Vec::new(); - - for col in 1..=buffer_line_count { - let cost = self.matrix.get(query_line_count, col).cost; - if cost < best_cost { - best_cost = cost; - matches_with_best_cost.clear(); - matches_with_best_cost.push(col as u32); - } else if cost == best_cost { - matches_with_best_cost.push(col as u32); - } - } - - // Find ranges for the matches - for &match_end_col in &matches_with_best_cost { - let mut matched_lines = 0; - let mut query_row = query_line_count; - let mut match_start_col = match_end_col; - while query_row > 0 && match_start_col > 0 { - let current = self.matrix.get(query_row, match_start_col as usize); - match current.direction { - SearchDirection::Diagonal => { - query_row -= 1; - match_start_col -= 1; - matched_lines += 1; - } - SearchDirection::Up => { - query_row -= 1; - } - SearchDirection::Left => { - match_start_col -= 1; - } - } - } - - let buffer_row_start = match_start_col + point_range.start.row; - let buffer_row_end = match_end_col + point_range.start.row; - - let matched_buffer_row_count = buffer_row_end - buffer_row_start; - let matched_ratio = matched_lines as f32 - / (matched_buffer_row_count as f32).max(query_line_count as f32); - if matched_ratio >= 0.8 { - let buffer_start_ix = self - .snapshot - .point_to_offset(Point::new(buffer_row_start, 0)); - let buffer_end_ix = self.snapshot.point_to_offset(Point::new( - buffer_row_end - 1, - self.snapshot.line_len(buffer_row_end - 1), - )); - return Some((best_cost, buffer_start_ix..buffer_end_ix)); - } - } - - None - } -} - -fn fuzzy_eq(left: &str, right: &str) -> bool { - const THRESHOLD: f64 = 0.8; - - let min_levenshtein = left.len().abs_diff(right.len()); - let min_normalized_levenshtein = - 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64); - if min_normalized_levenshtein < THRESHOLD { - return false; - } - - strsim::normalized_levenshtein(left, right) >= THRESHOLD -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -enum SearchDirection { - Up, - Left, - Diagonal, -} - -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct SearchState { - cost: u32, - direction: SearchDirection, -} - -impl SearchState { - fn new(cost: u32, direction: SearchDirection) -> Self { - Self { cost, direction } - } -} - -struct SearchMatrix { - cols: usize, - rows: usize, - data: Vec, -} - -impl SearchMatrix { - fn new(cols: usize) -> Self { - SearchMatrix { - cols, - rows: 0, - data: Vec::new(), - } - } - - fn reset(&mut self, rows: usize, cols: usize) { - self.rows = rows; - self.cols = cols; - self.data - .fill(SearchState::new(0, SearchDirection::Diagonal)); - self.data.resize( - self.rows * self.cols, - SearchState::new(0, SearchDirection::Diagonal), - ); - } - - fn get(&self, row: usize, col: usize) -> SearchState { - debug_assert!(row < self.rows); - debug_assert!(col < self.cols); - self.data[row * self.cols + col] - } - - fn set(&mut self, row: usize, col: usize, state: SearchState) { - debug_assert!(row < self.rows && col < self.cols); - self.data[row * self.cols + col] = state; - } -} - -#[cfg(test)] -mod tests { - use super::*; - use gpui::TestAppContext; - use indoc::indoc; - use language::Point; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use util::path; - - #[test] - fn test_extract_xml_edits() { - let input = indoc! {r#" - - - old content - - - new content - - - "#}; - - let result = extract_xml_replacements(input).unwrap(); - assert_eq!(result.file_path, "test.rs"); - assert_eq!(result.replacements.len(), 1); - assert_eq!(result.replacements[0].0, "old content"); - assert_eq!(result.replacements[0].1, "new content"); - } - - #[test] - fn test_extract_xml_edits_with_wrong_closing_tags() { - let input = indoc! {r#" - - - old content - - - new content - - - "#}; - - let result = extract_xml_replacements(input).unwrap(); - assert_eq!(result.file_path, "test.rs"); - assert_eq!(result.replacements.len(), 1); - assert_eq!(result.replacements[0].0, "old content"); - assert_eq!(result.replacements[0].1, "new content"); - } - - #[test] - fn test_extract_xml_edits_with_xml_like_content() { - let input = indoc! {r#" - - - - - - - - - "#}; - - let result = extract_xml_replacements(input).unwrap(); - assert_eq!(result.file_path, "component.tsx"); - assert_eq!(result.replacements.len(), 1); - assert_eq!(result.replacements[0].0, ""); - assert_eq!( - result.replacements[0].1, - "" - ); - } - - #[test] - fn test_extract_xml_edits_with_conflicting_content() { - let input = indoc! {r#" - - - - - - - - - "#}; - - let result = extract_xml_replacements(input).unwrap(); - assert_eq!(result.file_path, "component.tsx"); - assert_eq!(result.replacements.len(), 1); - assert_eq!(result.replacements[0].0, ""); - assert_eq!(result.replacements[0].1, ""); - } - - #[test] - fn test_extract_xml_edits_multiple_pairs() { - let input = indoc! {r#" - Some reasoning before edits. Lots of thinking going on here - - - - first old - - - first new - - - second old - - - second new - - - "#}; - - let result = extract_xml_replacements(input).unwrap(); - assert_eq!(result.file_path, "test.rs"); - assert_eq!(result.replacements.len(), 2); - assert_eq!(result.replacements[0].0, "first old"); - assert_eq!(result.replacements[0].1, "first new"); - assert_eq!(result.replacements[1].0, "second old"); - assert_eq!(result.replacements[1].1, "second new"); - } - - #[test] - fn test_extract_xml_edits_unexpected_eof() { - let input = indoc! {r#" - - - first old - - - nine ten eleven twelve - - - nine TEN eleven twelve! - - - "#}; - - let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)]; - let (buffer, edits) = parse_xml_edits(edits, |_path| { - Some((&buffer_snapshot, included_ranges.as_slice())) - }) - .await - .unwrap(); - - let edits = edits - .into_iter() - .map(|(range, text)| (range.to_point(&buffer), text)) - .collect::>(); - assert_eq!( - edits, - &[ - (Point::new(2, 5)..Point::new(2, 8), "TEN".into()), - (Point::new(2, 22)..Point::new(2, 22), "!".into()) - ] - ); - } - - fn init_test(cx: &mut TestAppContext) -> Arc { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - }); - - FakeFs::new(cx.background_executor.clone()) - } -} diff --git a/crates/edit_prediction/src/zeta1.rs b/crates/edit_prediction/src/zeta1.rs index ad630484d392d75849bd33a52a55e63ea77ca23f..ed531749cb39d10d71d18947990dd1972f23a986 100644 --- a/crates/edit_prediction/src/zeta1.rs +++ b/crates/edit_prediction/src/zeta1.rs @@ -1,22 +1,23 @@ use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant}; use crate::{ - EditPredictionId, EditPredictionStore, ZedUpdateRequiredError, + DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput, + EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError, cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count}, - prediction::{EditPredictionInputs, EditPredictionResult}, + prediction::EditPredictionResult, }; use anyhow::{Context as _, Result}; use cloud_llm_client::{ PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse, - predict_edits_v3::Event, }; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task}; use language::{ - Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff, + Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff, }; use project::{Project, ProjectPath}; use release_channel::AppVersion; use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification}; +use zeta_prompt::{Event, ZetaPromptInput}; const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; const START_OF_FILE_MARKER: &str = "<|start_of_file|>"; @@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500; pub(crate) fn request_prediction_with_zeta1( store: &mut EditPredictionStore, - project: &Entity, - buffer: &Entity, - snapshot: BufferSnapshot, - position: language::Anchor, - events: Vec>, - trigger: PredictEditsRequestTrigger, + EditPredictionModelInput { + project, + buffer, + snapshot, + position, + events, + trigger, + debug_tx, + .. + }: EditPredictionModelInput, cx: &mut Context, ) -> Task>> { - let buffer = buffer.clone(); let buffer_snapshotted_at = Instant::now(); let client = store.client.clone(); let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); let (git_info, can_collect_file) = if let Some(file) = snapshot.file() { - let can_collect_file = store.can_collect_file(project, file, cx); + let can_collect_file = store.can_collect_file(&project, file, cx); let git_info = if can_collect_file { - git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx) + git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx) } else { None }; @@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1( ) .await; - let inputs = EditPredictionInputs { + let context_start_offset = context_range.start.to_offset(&snapshot); + let editable_offset_range = editable_range.to_offset(&snapshot); + + let inputs = ZetaPromptInput { events: included_events.into(), - included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile { - path: full_path.clone(), - max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row), - excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt { - start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row), - text: snapshot - .text_for_range(context_range) - .collect::() - .into(), - }], - }], - cursor_point: cloud_llm_client::predict_edits_v3::Point { - column: cursor_point.column, - line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row), - }, + related_files: vec![].into(), cursor_path: full_path, + cursor_excerpt: snapshot + .text_for_range(context_range) + .collect::() + .into(), + editable_range_in_excerpt: (editable_range.start - context_start_offset) + ..(editable_offset_range.end - context_start_offset), + cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset, }; - // let response = perform_predict_edits(PerformPredictEditsParams { - // client, - // llm_token, - // app_version, - // body, - // }) - // .await; + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(DebugEvent::EditPredictionStarted( + EditPredictionStartedDebugEvent { + buffer: buffer.downgrade(), + prompt: Some(serde_json::to_string(&inputs).unwrap()), + position, + }, + )) + .ok(); + } let (response, usage) = match response { Ok(response) => response, @@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1( .ok(); } + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(DebugEvent::EditPredictionFinished( + EditPredictionFinishedDebugEvent { + buffer: buffer.downgrade(), + model_output: Some(response.output_excerpt.clone()), + position, + }, + )) + .ok(); + } + let edit_prediction = process_completion_response( response, buffer, @@ -226,7 +242,7 @@ fn process_completion_response( buffer: Entity, snapshot: &BufferSnapshot, editable_range: Range, - inputs: EditPredictionInputs, + inputs: ZetaPromptInput, buffer_snapshotted_at: Instant, received_response_at: Instant, cx: &AsyncApp, diff --git a/crates/edit_prediction/src/zeta2.rs b/crates/edit_prediction/src/zeta2.rs index e542bc7e86e6e381766bbedac6a15f431e0693f1..034954f5760939fc31b3e5e1e8a09737c5b2e568 100644 --- a/crates/edit_prediction/src/zeta2.rs +++ b/crates/edit_prediction/src/zeta2.rs @@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind; use crate::open_ai_response::text_from_response; use crate::prediction::EditPredictionResult; use crate::{ - DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs, - EditPredictionRequestedDebugEvent, EditPredictionStore, + DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId, + EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, }; -use anyhow::{Result, anyhow, bail}; -use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat}; -use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger}; -use cloud_zeta2_prompt::CURSOR_MARKER; -use edit_prediction_context::{EditPredictionExcerpt, Line}; -use edit_prediction_context::{RelatedExcerpt, RelatedFile}; -use futures::channel::oneshot; -use gpui::{Entity, Task, prelude::*}; -use language::{Anchor, BufferSnapshot}; -use language::{Buffer, Point, ToOffset as _, ToPoint}; -use project::{Project, ProjectItem as _}; +use anyhow::{Result, anyhow}; +use cloud_llm_client::EditPredictionRejectReason; +use gpui::{Task, prelude::*}; +use language::{OffsetRangeExt as _, ToOffset as _, ToPoint}; use release_channel::AppVersion; -use std::{ - env, - path::Path, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{path::Path, sync::Arc, time::Instant}; +use zeta_prompt::CURSOR_MARKER; +use zeta_prompt::format_zeta_prompt; + +const MAX_CONTEXT_TOKENS: usize = 150; +const MAX_REWRITE_TOKENS: usize = 350; pub fn request_prediction_with_zeta2( store: &mut EditPredictionStore, - project: &Entity, - active_buffer: &Entity, - active_snapshot: BufferSnapshot, - position: Anchor, - events: Vec>, - mut included_files: Vec, - trigger: PredictEditsRequestTrigger, + EditPredictionModelInput { + buffer, + snapshot, + position, + related_files, + events, + debug_tx, + .. + }: EditPredictionModelInput, cx: &mut Context, ) -> Task>> { - let options = store.options.clone(); let buffer_snapshotted_at = Instant::now(); - let Some((excerpt_path, active_project_path)) = active_snapshot + let Some(excerpt_path) = snapshot .file() .map(|file| -> Arc { file.full_path(cx).into() }) - .zip(active_buffer.read(cx).project_path(cx)) else { return Task::ready(Err(anyhow!("No file path for excerpt"))); }; @@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2( let client = store.client.clone(); let llm_token = store.llm_token.clone(); let app_version = AppVersion::global(cx); - let debug_tx = store.debug_tx.clone(); - - let file = active_buffer.read(cx).file(); - - let active_file_full_path = file.as_ref().map(|f| f.full_path(cx)); - - // TODO data collection - let can_collect_data = file - .as_ref() - .map_or(false, |file| store.can_collect_file(project, file, cx)); #[cfg(feature = "eval-support")] let eval_cache = store.eval_cache.clone(); let request_task = cx.background_spawn({ - let active_buffer = active_buffer.clone(); async move { - let cursor_offset = position.to_offset(&active_snapshot); - let cursor_point = cursor_offset.to_point(&active_snapshot); - - let before_retrieval = Instant::now(); - - let excerpt_options = options.context; - - let Some(excerpt) = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &active_snapshot, - &excerpt_options, - ) else { - return Ok((None, None)); - }; - - let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start) - ..active_snapshot.anchor_before(excerpt.range.end); - let related_excerpt = RelatedExcerpt { - anchor_range: excerpt_anchor_range.clone(), - point_range: Point::new(excerpt.line_range.start.0, 0) - ..Point::new(excerpt.line_range.end.0, 0), - text: active_snapshot.as_rope().slice(excerpt.range), - }; - - if let Some(buffer_ix) = included_files - .iter() - .position(|file| file.buffer.entity_id() == active_buffer.entity_id()) - { - let file = &mut included_files[buffer_ix]; - file.excerpts.push(related_excerpt); - file.merge_excerpts(); - let last_ix = included_files.len() - 1; - included_files.swap(buffer_ix, last_ix); - } else { - let active_file = RelatedFile { - path: active_project_path, - buffer: active_buffer.downgrade(), - excerpts: vec![related_excerpt], - max_row: active_snapshot.max_point().row, - }; - included_files.push(active_file); - } - - let included_files = included_files - .iter() - .map(|related_file| predict_edits_v3::RelatedFile { - path: Arc::from(related_file.path.path.as_std_path()), - max_row: Line(related_file.max_row), - excerpts: related_file - .excerpts - .iter() - .map(|excerpt| predict_edits_v3::Excerpt { - start_line: Line(excerpt.point_range.start.row), - text: excerpt.text.to_string().into(), - }) - .collect(), - }) - .collect::>(); - - let cloud_request = predict_edits_v3::PredictEditsRequest { - excerpt_path, - excerpt: String::new(), - excerpt_line_range: Line(0)..Line(0), - excerpt_range: 0..0, - cursor_point: predict_edits_v3::Point { - line: predict_edits_v3::Line(cursor_point.row), - column: cursor_point.column, - }, - related_files: included_files, + let cursor_offset = position.to_offset(&snapshot); + let (editable_offset_range, prompt_input) = zeta2_prompt_input( + &snapshot, + related_files, events, - can_collect_data, - debug_info: debug_tx.is_some(), - prompt_max_bytes: Some(options.max_prompt_bytes), - prompt_format: options.prompt_format, - excerpt_parent: None, - git_info: None, - trigger, - }; - - let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request); - - let inputs = EditPredictionInputs { - included_files: cloud_request.related_files, - events: cloud_request.events, - cursor_point: cloud_request.cursor_point, - cursor_path: cloud_request.excerpt_path, - }; - - let retrieval_time = Instant::now() - before_retrieval; + excerpt_path, + cursor_offset, + ); - let debug_response_tx = if let Some(debug_tx) = &debug_tx { - let (response_tx, response_rx) = oneshot::channel(); + let prompt = format_zeta_prompt(&prompt_input); + if let Some(debug_tx) = &debug_tx { debug_tx - .unbounded_send(DebugEvent::EditPredictionRequested( - EditPredictionRequestedDebugEvent { - inputs: inputs.clone(), - retrieval_time, - buffer: active_buffer.downgrade(), - local_prompt: match prompt_result.as_ref() { - Ok(prompt) => Ok(prompt.clone()), - Err(err) => Err(err.to_string()), - }, + .unbounded_send(DebugEvent::EditPredictionStarted( + EditPredictionStartedDebugEvent { + buffer: buffer.downgrade(), + prompt: Some(prompt.clone()), position, - response_rx, }, )) .ok(); - Some(response_tx) - } else { - None - }; - - if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { - if let Some(debug_response_tx) = debug_response_tx { - debug_response_tx - .send((Err("Request skipped".to_string()), Duration::ZERO)) - .ok(); - } - anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set") } - let prompt = prompt_result?; - let generation_params = - cloud_zeta2_prompt::generation_params(cloud_request.prompt_format); let request = open_ai::Request { model: EDIT_PREDICTIONS_MODEL_ID.clone(), messages: vec![open_ai::RequestMessage::User { @@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2( }], stream: false, max_completion_tokens: None, - stop: generation_params.stop.unwrap_or_default(), - temperature: generation_params.temperature.or(Some(0.7)), + stop: Default::default(), + temperature: Default::default(), tool_choice: None, parallel_tool_calls: None, tools: vec![], @@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2( log::trace!("Sending edit prediction request"); - let before_request = Instant::now(); let response = EditPredictionStore::send_raw_llm_request( request, client, @@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2( ) .await; let received_response_at = Instant::now(); - let request_time = received_response_at - before_request; log::trace!("Got edit prediction response"); - if let Some(debug_response_tx) = debug_response_tx { - debug_response_tx - .send(( - response - .as_ref() - .map_err(|err| err.to_string()) - .map(|response| response.0.clone()), - request_time, - )) - .ok(); - } - let (res, usage) = response?; let request_id = EditPredictionId(res.id.clone().into()); let Some(mut output_text) = text_from_response(res) else { return Ok((Some((request_id, None)), usage)); }; + if let Some(debug_tx) = &debug_tx { + debug_tx + .unbounded_send(DebugEvent::EditPredictionFinished( + EditPredictionFinishedDebugEvent { + buffer: buffer.downgrade(), + position, + model_output: Some(output_text.clone()), + }, + )) + .ok(); + } + if output_text.contains(CURSOR_MARKER) { log::trace!("Stripping out {CURSOR_MARKER} from response"); output_text = output_text.replace(CURSOR_MARKER, ""); } - let get_buffer_from_context = |path: &Path| { - if Some(path) == active_file_full_path.as_deref() { - Some(( - &active_snapshot, - std::slice::from_ref(&excerpt_anchor_range), - )) - } else { - None - } - }; - - let (_, edits) = match options.prompt_format { - PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => { - if output_text.contains("--- a/\n+++ b/\nNo edits") { - let edits = vec![]; - (&active_snapshot, edits) - } else { - crate::udiff::parse_diff(&output_text, get_buffer_from_context).await? - } - } - PromptFormat::OldTextNewText => { - crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await? - } - _ => { - bail!("unsupported prompt format {}", options.prompt_format) - } - }; + let old_text = snapshot + .text_for_range(editable_offset_range.clone()) + .collect::(); + let edits: Vec<_> = language::text_diff(&old_text, &output_text) + .into_iter() + .map(|(range, text)| { + ( + snapshot.anchor_after(editable_offset_range.start + range.start) + ..snapshot.anchor_before(editable_offset_range.start + range.end), + text, + ) + }) + .collect(); anyhow::Ok(( Some(( request_id, Some(( - inputs, - active_buffer, - active_snapshot.clone(), + prompt_input, + buffer, + snapshot.clone(), edits, received_response_at, )), @@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2( )) }) } + +pub fn zeta2_prompt_input( + snapshot: &language::BufferSnapshot, + related_files: Arc<[zeta_prompt::RelatedFile]>, + events: Vec>, + excerpt_path: Arc, + cursor_offset: usize, +) -> (std::ops::Range, zeta_prompt::ZetaPromptInput) { + let cursor_point = cursor_offset.to_point(snapshot); + + let (editable_range, context_range) = + crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position( + cursor_point, + snapshot, + MAX_CONTEXT_TOKENS, + MAX_REWRITE_TOKENS, + ); + + let context_start_offset = context_range.start.to_offset(snapshot); + let editable_offset_range = editable_range.to_offset(snapshot); + let cursor_offset_in_excerpt = cursor_offset - context_start_offset; + let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset) + ..(editable_offset_range.end - context_start_offset); + + let prompt_input = zeta_prompt::ZetaPromptInput { + cursor_path: excerpt_path, + cursor_excerpt: snapshot + .text_for_range(context_range) + .collect::() + .into(), + editable_range_in_excerpt, + cursor_offset_in_excerpt, + events, + related_files, + }; + (editable_offset_range, prompt_input) +} diff --git a/crates/edit_prediction_cli/Cargo.toml b/crates/edit_prediction_cli/Cargo.toml index 26a060994d75a2c194cc159c33d88fbc296dfa47..0e7fff8d70156c58147069f8da64035d6a80adc8 100644 --- a/crates/edit_prediction_cli/Cargo.toml +++ b/crates/edit_prediction_cli/Cargo.toml @@ -9,7 +9,7 @@ license = "GPL-3.0-or-later" workspace = true [[bin]] -name = "ep_cli" +name = "ep" path = "src/main.rs" [dependencies] @@ -20,10 +20,9 @@ chrono.workspace = true clap.workspace = true client.workspace = true cloud_llm_client.workspace= true -cloud_zeta2_prompt.workspace = true collections.workspace = true debug_adapter_extension.workspace = true -edit_prediction_context.workspace = true +dirs.workspace = true extension.workspace = true fs.workspace = true futures.workspace = true @@ -51,12 +50,21 @@ smol.workspace = true sqlez.workspace = true sqlez_macros.workspace = true terminal_view.workspace = true -toml.workspace = true util.workspace = true watch.workspace = true edit_prediction = { workspace = true, features = ["eval-support"] } +wasmtime.workspace = true +zeta_prompt.workspace = true zlog.workspace = true +# Wasmtime is included as a dependency in order to enable the same +# features that are enabled in Zed. +# +# If we don't enable these features we get crashes when creating +# a Tree-sitter WasmStore. +[package.metadata.cargo-machete] +ignored = ["wasmtime"] + [dev-dependencies] indoc.workspace = true gpui = { workspace = true, features = ["test-support"] } diff --git a/crates/edit_prediction_cli/src/training/llm_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs similarity index 89% rename from crates/edit_prediction_cli/src/training/llm_client.rs rename to crates/edit_prediction_cli/src/anthropic_client.rs index ebecbe915d36a9a456296e818e559c654370f939..8afc4d1c03f8a37ae258cc2926daf85caebe3d8a 100644 --- a/crates/edit_prediction_cli/src/training/llm_client.rs +++ b/crates/edit_prediction_cli/src/anthropic_client.rs @@ -5,11 +5,13 @@ use anthropic::{ use anyhow::Result; use http_client::HttpClient; use indoc::indoc; +use reqwest_client::ReqwestClient; use sqlez::bindable::Bind; use sqlez::bindable::StaticColumnCount; use sqlez_macros::sql; use std::hash::Hash; use std::hash::Hasher; +use std::path::Path; use std::sync::Arc; pub struct PlainLlmClient { @@ -18,7 +20,8 @@ pub struct PlainLlmClient { } impl PlainLlmClient { - fn new(http_client: Arc) -> Result { + fn new() -> Result { + let http_client: Arc = Arc::new(ReqwestClient::new()); let api_key = std::env::var("ANTHROPIC_API_KEY") .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?; Ok(Self { @@ -29,12 +32,12 @@ impl PlainLlmClient { async fn generate( &self, - model: String, + model: &str, max_tokens: u64, messages: Vec, ) -> Result { let request = AnthropicRequest { - model, + model: model.to_string(), max_tokens, messages, tools: Vec::new(), @@ -105,11 +108,12 @@ struct SerializableMessage { } impl BatchingLlmClient { - fn new(cache_path: &str, http_client: Arc) -> Result { + fn new(cache_path: &Path) -> Result { + let http_client: Arc = Arc::new(ReqwestClient::new()); let api_key = std::env::var("ANTHROPIC_API_KEY") .map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?; - let connection = sqlez::connection::Connection::open_file(&cache_path); + let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap()); let mut statement = sqlez::statement::Statement::prepare( &connection, indoc! {" @@ -182,16 +186,16 @@ impl BatchingLlmClient { async fn generate( &self, - model: String, + model: &str, max_tokens: u64, messages: Vec, ) -> Result> { - let response = self.lookup(&model, max_tokens, &messages)?; + let response = self.lookup(model, max_tokens, &messages)?; if let Some(response) = response { return Ok(Some(response)); } - self.mark_for_batch(&model, max_tokens, &messages)?; + self.mark_for_batch(model, max_tokens, &messages)?; Ok(None) } @@ -258,7 +262,7 @@ impl BatchingLlmClient { } } } - log::info!("Uploaded {} successful requests", success_count); + log::info!("Downloaded {} successful requests", success_count); } } @@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String { .join("\n") } -pub enum LlmClient { +pub enum AnthropicClient { // No batching Plain(PlainLlmClient), Batch(BatchingLlmClient), Dummy, } -impl LlmClient { - pub fn plain(http_client: Arc) -> Result { - Ok(Self::Plain(PlainLlmClient::new(http_client)?)) +impl AnthropicClient { + pub fn plain() -> Result { + Ok(Self::Plain(PlainLlmClient::new()?)) } - pub fn batch(cache_path: &str, http_client: Arc) -> Result { - Ok(Self::Batch(BatchingLlmClient::new( - cache_path, - http_client, - )?)) + pub fn batch(cache_path: &Path) -> Result { + Ok(Self::Batch(BatchingLlmClient::new(cache_path)?)) } #[allow(dead_code)] @@ -389,29 +390,29 @@ impl LlmClient { pub async fn generate( &self, - model: String, + model: &str, max_tokens: u64, messages: Vec, ) -> Result> { match self { - LlmClient::Plain(plain_llm_client) => plain_llm_client + AnthropicClient::Plain(plain_llm_client) => plain_llm_client .generate(model, max_tokens, messages) .await .map(Some), - LlmClient::Batch(batching_llm_client) => { + AnthropicClient::Batch(batching_llm_client) => { batching_llm_client .generate(model, max_tokens, messages) .await } - LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"), + AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"), } } pub async fn sync_batches(&self) -> Result<()> { match self { - LlmClient::Plain(_) => Ok(()), - LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await, - LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"), + AnthropicClient::Plain(_) => Ok(()), + AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await, + AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"), } } } diff --git a/crates/edit_prediction_cli/src/evaluate.rs b/crates/edit_prediction_cli/src/evaluate.rs deleted file mode 100644 index 686c8ce7e7865f265d6bf17e51ca9477194e5252..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_cli/src/evaluate.rs +++ /dev/null @@ -1,641 +0,0 @@ -use crate::metrics::{self, Scores}; -use std::{ - collections::HashMap, - io::{IsTerminal, Write}, - sync::Arc, -}; - -use anyhow::Result; -use edit_prediction::{EditPredictionStore, udiff::DiffLine}; -use gpui::{AsyncApp, Entity}; -use project::Project; -use util::ResultExt as _; - -use crate::{ - EvaluateArguments, PredictionOptions, - example::{Example, NamedExample}, - headless::ZetaCliAppState, - paths::print_run_data_dir, - predict::{PredictionDetails, perform_predict, setup_store}, -}; - -#[derive(Debug)] -pub(crate) struct ExecutionData { - execution_id: String, - diff: String, - reasoning: String, -} - -pub async fn run_evaluate( - args: EvaluateArguments, - app_state: &Arc, - cx: &mut AsyncApp, -) { - if args.example_paths.is_empty() { - eprintln!("No examples provided"); - return; - } - - let all_tasks = args.example_paths.into_iter().map(|path| { - let options = args.options.clone(); - let app_state = app_state.clone(); - let example = NamedExample::load(&path).expect("Failed to load example"); - - cx.spawn(async move |cx| { - let project = example.setup_project(&app_state, cx).await.unwrap(); - - let providers = (0..args.repetitions) - .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap()) - .collect::>(); - - let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); - - let tasks = providers - .into_iter() - .enumerate() - .map(move |(repetition_ix, store)| { - let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16); - let example = example.clone(); - let project = project.clone(); - let options = options.clone(); - - cx.spawn(async move |cx| { - let name = example.name.clone(); - run_evaluate_one( - example, - repetition_ix, - project, - store, - options, - !args.skip_prediction, - cx, - ) - .await - .map_err(|err| (err, name, repetition_ix)) - }) - }); - futures::future::join_all(tasks).await - }) - }); - let all_results = futures::future::join_all(all_tasks).await; - - write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap(); - if let Some(mut output_file) = - std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err() - { - write_aggregated_scores(&mut output_file, &all_results).log_err(); - }; - - if args.repetitions > 1 { - if let Err(e) = write_bucketed_analysis(&all_results) { - eprintln!("Failed to write bucketed analysis: {:?}", e); - } - } - - print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal()); -} - -fn write_aggregated_scores( - w: &mut impl std::io::Write, - all_results: &Vec< - Vec)>>, - >, -) -> Result<()> { - let mut successful = Vec::new(); - let mut failed_count = 0; - - for result in all_results.iter().flatten() { - match result { - Ok((eval_result, _execution_data)) => successful.push(eval_result), - Err((err, name, repetition_ix)) => { - if failed_count == 0 { - writeln!(w, "## Errors\n")?; - } - - failed_count += 1; - writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?; - } - } - } - - if successful.len() > 1 { - let edit_scores = successful - .iter() - .filter_map(|r| r.edit_scores.clone()) - .collect::>(); - let has_edit_predictions = edit_scores.len() > 0; - let aggregated_result = EvaluationResult { - context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)), - edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)), - prompt_len: successful.iter().map(|r| r.prompt_len).sum::() / successful.len(), - generated_len: successful.iter().map(|r| r.generated_len).sum::() - / successful.len(), - }; - - writeln!(w, "\n{}", "-".repeat(80))?; - writeln!(w, "\n## TOTAL SCORES")?; - writeln!(w, "{:#}", aggregated_result)?; - } - - if successful.len() + failed_count > 1 { - writeln!( - w, - "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉", - successful.len(), - successful.len() + failed_count, - (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0 - )?; - } - - Ok(()) -} - -pub async fn run_evaluate_one( - example: NamedExample, - repetition_ix: Option, - project: Entity, - store: Entity, - prediction_options: PredictionOptions, - predict: bool, - cx: &mut AsyncApp, -) -> Result<(EvaluationResult, ExecutionData)> { - let predict_result = perform_predict( - example.clone(), - project, - store, - repetition_ix, - prediction_options, - cx, - ) - .await?; - - let evaluation_result = evaluate(&example.example, &predict_result, predict); - - if repetition_ix.is_none() { - write_eval_result( - &example, - &predict_result, - &evaluation_result, - &mut std::io::stdout(), - std::io::stdout().is_terminal(), - predict, - )?; - } - - if let Some(mut results_file) = - std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err() - { - write_eval_result( - &example, - &predict_result, - &evaluation_result, - &mut results_file, - false, - predict, - ) - .log_err(); - } - - let execution_data = ExecutionData { - execution_id: if let Some(rep_ix) = repetition_ix { - format!("{:03}", rep_ix) - } else { - example.name.clone() - }, - diff: predict_result.diff.clone(), - reasoning: std::fs::read_to_string( - predict_result - .run_example_dir - .join("prediction_response.md"), - ) - .unwrap_or_default(), - }; - - anyhow::Ok((evaluation_result, execution_data)) -} - -fn write_eval_result( - example: &NamedExample, - predictions: &PredictionDetails, - evaluation_result: &EvaluationResult, - out: &mut impl Write, - use_color: bool, - predict: bool, -) -> Result<()> { - if predict { - writeln!( - out, - "## Expected edit prediction:\n\n```diff\n{}\n```\n", - compare_diffs( - &example.example.expected_patch, - &predictions.diff, - use_color - ) - )?; - writeln!( - out, - "## Actual edit prediction:\n\n```diff\n{}\n```\n", - compare_diffs( - &predictions.diff, - &example.example.expected_patch, - use_color - ) - )?; - } - - writeln!(out, "{:#}", evaluation_result)?; - - anyhow::Ok(()) -} - -#[derive(Debug, Default, Clone)] -pub struct EditScores { - pub line_match: Scores, - pub chr_f: f64, -} - -impl EditScores { - pub fn aggregate(scores: &[EditScores]) -> EditScores { - let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match)); - let chr_f = scores.iter().map(|s| s.chr_f).sum::() / scores.len() as f64; - - EditScores { line_match, chr_f } - } -} - -#[derive(Debug, Default)] -pub struct EvaluationResult { - pub edit_scores: Option, - pub context_scores: Scores, - pub prompt_len: usize, - pub generated_len: usize, -} - -impl std::fmt::Display for EvaluationResult { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if f.alternate() { - self.fmt_table(f) - } else { - self.fmt_markdown(f) - } - } -} - -impl EvaluationResult { - fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - r#" -### Context Scores -{} -"#, - self.context_scores.to_markdown(), - )?; - if let Some(scores) = &self.edit_scores { - write!( - f, - r#" - ### Edit Prediction Scores - {}"#, - scores.line_match.to_markdown() - )?; - } - Ok(()) - } - - fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - writeln!(f, "#### Prompt Statistics")?; - writeln!(f, "─────────────────────────")?; - writeln!(f, "Prompt_len Generated_len")?; - writeln!(f, "─────────────────────────")?; - writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?; - writeln!(f)?; - writeln!(f)?; - writeln!(f, "#### Performance Scores")?; - writeln!( - f, - "──────────────────────────────────────────────────────────────────" - )?; - writeln!( - f, - " TP FP FN Precision Recall F1" - )?; - writeln!( - f, - "──────────────────────────────────────────────────────────────────" - )?; - writeln!( - f, - "Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}", - self.context_scores.true_positives, - self.context_scores.false_positives, - self.context_scores.false_negatives, - self.context_scores.precision() * 100.0, - self.context_scores.recall() * 100.0, - self.context_scores.f1_score() * 100.0 - )?; - if let Some(edit_scores) = &self.edit_scores { - let line_match = &edit_scores.line_match; - writeln!(f, "Edit Prediction")?; - writeln!( - f, - " ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}", - line_match.true_positives, - line_match.false_positives, - line_match.false_negatives, - line_match.precision() * 100.0, - line_match.recall() * 100.0, - line_match.f1_score() * 100.0 - )?; - writeln!( - f, - " └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}", - "-", "-", "-", "-", "-", edit_scores.chr_f - )?; - } - Ok(()) - } -} - -fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult { - let mut eval_result = EvaluationResult { - prompt_len: preds.prompt_len, - generated_len: preds.generated_len, - ..Default::default() - }; - - if predict { - // todo: alternatives for patches - let expected_patch = example - .expected_patch - .lines() - .map(DiffLine::parse) - .collect::>(); - let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::>(); - - let line_match = metrics::line_match_score(&expected_patch, &actual_patch); - let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch); - - eval_result.edit_scores = Some(EditScores { line_match, chr_f }); - } - - eval_result -} - -/// Return annotated `patch_a` so that: -/// Additions and deletions that are not present in `patch_b` will be highlighted in red. -/// Additions and deletions that are present in `patch_b` will be highlighted in green. -pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String { - let green = if use_color { "\x1b[32m✓ " } else { "" }; - let red = if use_color { "\x1b[31m✗ " } else { "" }; - let neutral = if use_color { " " } else { "" }; - let reset = if use_color { "\x1b[0m" } else { "" }; - let lines_a = patch_a.lines().map(DiffLine::parse); - let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect(); - - let annotated = lines_a - .map(|line| match line { - DiffLine::Addition(_) | DiffLine::Deletion(_) => { - if lines_b.contains(&line) { - format!("{green}{line}{reset}") - } else { - format!("{red}{line}{reset}") - } - } - _ => format!("{neutral}{line}{reset}"), - }) - .collect::>(); - - annotated.join("\n") -} - -fn write_bucketed_analysis( - all_results: &Vec< - Vec)>>, - >, -) -> Result<()> { - #[derive(Debug)] - struct EditBucket { - diff: String, - is_correct: bool, - execution_indices: Vec, - reasoning_samples: Vec, - } - - let mut total_executions = 0; - let mut empty_predictions = Vec::new(); - let mut errors = Vec::new(); - - let mut buckets: HashMap = HashMap::new(); - - for result in all_results.iter().flatten() { - total_executions += 1; - - let (evaluation_result, execution_data) = match result { - Ok((eval_result, execution_data)) => { - if execution_data.diff.is_empty() { - empty_predictions.push(execution_data); - continue; - } - (eval_result, execution_data) - } - Err(err) => { - errors.push(err); - continue; - } - }; - - buckets - .entry(execution_data.diff.clone()) - .and_modify(|bucket| { - bucket - .execution_indices - .push(execution_data.execution_id.clone()); - bucket - .reasoning_samples - .push(execution_data.reasoning.clone()); - }) - .or_insert_with(|| EditBucket { - diff: execution_data.diff.clone(), - is_correct: { - evaluation_result - .edit_scores - .as_ref() - .map_or(false, |edit_scores| { - edit_scores.line_match.false_positives == 0 - && edit_scores.line_match.false_negatives == 0 - && edit_scores.line_match.true_positives > 0 - }) - }, - execution_indices: vec![execution_data.execution_id.clone()], - reasoning_samples: vec![execution_data.reasoning.clone()], - }); - } - - let mut sorted_buckets = buckets.into_values().collect::>(); - sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) { - (true, false) => std::cmp::Ordering::Less, - (false, true) => std::cmp::Ordering::Greater, - _ => b.execution_indices.len().cmp(&a.execution_indices.len()), - }); - - let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md"); - let mut output = std::fs::File::create(&output_path)?; - - writeln!(output, "# Bucketed Edit Analysis\n")?; - - writeln!(output, "## Summary\n")?; - writeln!(output, "- **Total executions**: {}", total_executions)?; - - let correct_count: usize = sorted_buckets - .iter() - .filter(|b| b.is_correct) - .map(|b| b.execution_indices.len()) - .sum(); - - let incorrect_count: usize = sorted_buckets - .iter() - .filter(|b| !b.is_correct) - .map(|b| b.execution_indices.len()) - .sum(); - - writeln!( - output, - "- **Correct predictions**: {} ({:.1}%)", - correct_count, - (correct_count as f64 / total_executions as f64) * 100.0 - )?; - - writeln!( - output, - "- **Incorrect predictions**: {} ({:.1}%)", - incorrect_count, - (incorrect_count as f64 / total_executions as f64) * 100.0 - )?; - - writeln!( - output, - "- **No Predictions**: {} ({:.1}%)", - empty_predictions.len(), - (empty_predictions.len() as f64 / total_executions as f64) * 100.0 - )?; - - let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count(); - writeln!( - output, - "- **Unique incorrect edit patterns**: {}\n", - unique_incorrect - )?; - - writeln!(output, "---\n")?; - - for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() { - if idx == 0 { - writeln!( - output, - "## Correct Predictions ({} occurrences)\n", - bucket.execution_indices.len() - )?; - } - - writeln!(output, "**Predicted Edit:**\n")?; - writeln!(output, "```diff")?; - writeln!(output, "{}", bucket.diff)?; - writeln!(output, "```\n")?; - - writeln!( - output, - "**Executions:** {}\n", - bucket.execution_indices.join(", ") - )?; - writeln!(output, "---\n")?; - } - - for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() { - writeln!( - output, - "## Incorrect Prediction #{} ({} occurrences)\n", - idx + 1, - bucket.execution_indices.len() - )?; - - writeln!(output, "**Predicted Edit:**\n")?; - writeln!(output, "```diff")?; - writeln!(output, "{}", bucket.diff)?; - writeln!(output, "```\n")?; - - writeln!( - output, - "**Executions:** {}\n", - bucket.execution_indices.join(", ") - )?; - - for (exec_id, reasoning) in bucket - .execution_indices - .iter() - .zip(bucket.reasoning_samples.iter()) - { - writeln!(output, "{}", fmt_execution(exec_id, reasoning))?; - } - - writeln!(output, "\n---\n")?; - } - - if !empty_predictions.is_empty() { - writeln!( - output, - "## No Predictions ({} occurrences)\n", - empty_predictions.len() - )?; - - for execution_data in &empty_predictions { - writeln!( - output, - "{}", - fmt_execution(&execution_data.execution_id, &execution_data.reasoning) - )?; - } - writeln!(output, "\n---\n")?; - } - - if !errors.is_empty() { - writeln!(output, "## Errors ({} occurrences)\n", errors.len())?; - - for (err, name, repetition_ix) in &errors { - writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?; - } - writeln!(output, "\n---\n")?; - } - - fn fmt_execution(exec_id: &str, reasoning: &str) -> String { - let exec_content = format!( - "\n### Execution {} `{}/{}/prediction_response.md`{}", - exec_id, - crate::paths::RUN_DIR.display(), - exec_id, - indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2) - ); - indent_text(&exec_content, 2) - } - - fn indent_text(text: &str, spaces: usize) -> String { - let indent = " ".repeat(spaces); - text.lines() - .collect::>() - .join(&format!("\n{}", indent)) - } - - Ok(()) -} - -fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option) -> String { - let err = format!("{err:?}") - .replace("", "\n```"); - format!( - "### ERROR {name}{}\n\n{err}\n", - repetition_ix - .map(|ix| format!(" [RUN {ix:03}]")) - .unwrap_or_default() - ) -} diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index 4f8c1867cd57d7fb5dbb9c2c08b63dccf2b97d30..a13b339ae69b9584f3b47186d8b6c36f458a2b76 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -1,59 +1,103 @@ +use crate::{ + PredictionProvider, PromptFormat, + metrics::ClassificationMetrics, + paths::{REPOS_DIR, WORKTREES_DIR}, +}; +use anyhow::{Context as _, Result}; +use edit_prediction::udiff::OpenedBuffers; +use gpui::Entity; +use http_client::Url; +use language::{Anchor, Buffer}; +use project::Project; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; use std::{ borrow::Cow, - cell::RefCell, - fmt::{self, Display}, - fs, - hash::Hash, - hash::Hasher, - io::Write, + io::{Read, Write}, mem, path::{Path, PathBuf}, - sync::{Arc, OnceLock}, }; +use zeta_prompt::RelatedFile; -use crate::headless::ZetaCliAppState; -use anyhow::{Context as _, Result, anyhow}; -use clap::ValueEnum; -use cloud_zeta2_prompt::CURSOR_MARKER; -use collections::HashMap; -use edit_prediction::udiff::OpenedBuffers; -use futures::{ - AsyncWriteExt as _, - lock::{Mutex, OwnedMutexGuard}, -}; -use futures::{FutureExt as _, future::Shared}; -use gpui::{AsyncApp, Entity, Task, http_client::Url}; -use language::{Anchor, Buffer}; -use project::{Project, ProjectPath}; -use pulldown_cmark::CowStr; -use serde::{Deserialize, Serialize}; -use util::{paths::PathStyle, rel_path::RelPath}; - -use crate::paths::{REPOS_DIR, WORKTREES_DIR}; - -const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; -const EDIT_HISTORY_HEADING: &str = "Edit History"; -const CURSOR_POSITION_HEADING: &str = "Cursor Position"; -const EXPECTED_PATCH_HEADING: &str = "Expected Patch"; -const EXPECTED_CONTEXT_HEADING: &str = "Expected Context"; -const REPOSITORY_URL_FIELD: &str = "repository_url"; -const REVISION_FIELD: &str = "revision"; - -#[derive(Debug, Clone)] -pub struct NamedExample { - pub name: String, - pub example: Example, -} - -#[derive(Clone, Debug, Hash, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Example { + #[serde(default)] + pub name: String, pub repository_url: String, pub revision: String, pub uncommitted_diff: String, - pub cursor_path: PathBuf, + pub cursor_path: Arc, pub cursor_position: String, pub edit_history: String, pub expected_patch: String, + + /// The full content of the file where an edit is being predicted, and the + /// actual cursor offset. + #[serde(skip_serializing_if = "Option::is_none")] + pub buffer: Option, + + /// The context retrieved for the prediction. This requires the worktree to + /// be loaded and the language server to be started. + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, + + /// The input and expected output from the edit prediction model. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + + /// The actual predictions from the model. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub predictions: Vec, + + /// The scores, for how well the actual predictions match the expected + /// predictions. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub score: Vec, + + /// The application state used to process this example. + #[serde(skip)] + pub state: Option, +} + +#[derive(Clone, Debug)] +pub struct ExampleState { + pub project: Entity, + pub buffer: Entity, + pub cursor_position: Anchor, + pub _open_buffers: OpenedBuffers, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExampleContext { + pub files: Arc<[RelatedFile]>, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExampleBuffer { + pub content: String, + pub cursor_row: u32, + pub cursor_column: u32, + pub cursor_offset: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExamplePrompt { + pub input: String, + pub expected_output: String, + pub format: PromptFormat, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExamplePrediction { + pub actual_patch: String, + pub actual_output: String, + pub provider: PredictionProvider, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ExampleScore { + pub delta_chr_f: f32, + pub line_match: ClassificationMetrics, } impl Example { @@ -90,485 +134,244 @@ impl Example { } } - pub async fn setup_worktree(&self, file_name: String) -> Result { - let (repo_owner, repo_name) = self.repo_name()?; - - let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()); - let repo_lock = lock_repo(&repo_dir).await; + pub fn worktree_path(&self) -> PathBuf { + WORKTREES_DIR + .join(&self.name) + .join(self.repo_name().unwrap().1.as_ref()) + } - if !repo_dir.is_dir() { - fs::create_dir_all(&repo_dir)?; - run_git(&repo_dir, &["init"]).await?; - run_git( - &repo_dir, - &["remote", "add", "origin", &self.repository_url], - ) - .await?; - } + pub fn repo_path(&self) -> PathBuf { + let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name"); + REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref()) + } +} - // Resolve the example to a revision, fetching it if needed. - let revision = run_git( - &repo_dir, - &["rev-parse", &format!("{}^{{commit}}", self.revision)], - ) - .await; - let revision = if let Ok(revision) = revision { - revision +pub fn read_examples(inputs: &[PathBuf]) -> Vec { + let mut examples = Vec::new(); + + let stdin_path: PathBuf = PathBuf::from("-"); + + let inputs = if inputs.is_empty() { + &[stdin_path] + } else { + inputs + }; + + for path in inputs { + let is_stdin = path.as_path() == Path::new("-"); + let content = if is_stdin { + let mut buffer = String::new(); + std::io::stdin() + .read_to_string(&mut buffer) + .expect("Failed to read from stdin"); + buffer } else { - if run_git( - &repo_dir, - &["fetch", "--depth", "1", "origin", &self.revision], - ) - .await - .is_err() - { - run_git(&repo_dir, &["fetch", "origin"]).await?; - } - let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?; - if revision != self.revision { - run_git(&repo_dir, &["tag", &self.revision, &revision]).await?; - } - revision + std::fs::read_to_string(path) + .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path)) }; - - // Create the worktree for this example if needed. - let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref()); - if worktree_path.is_dir() { - run_git(&worktree_path, &["clean", "--force", "-d"]).await?; - run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?; - run_git(&worktree_path, &["checkout", revision.as_str()]).await?; + let filename = path.file_stem().unwrap().to_string_lossy().to_string(); + let ext = if !is_stdin { + path.extension() + .map(|ext| ext.to_string_lossy().to_string()) + .unwrap_or_else(|| panic!("{} should have an extension", path.display())) } else { - let worktree_path_string = worktree_path.to_string_lossy(); - run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?; - run_git( - &repo_dir, - &["worktree", "add", "-f", &worktree_path_string, &file_name], - ) - .await?; - } - drop(repo_lock); - - // Apply the uncommitted diff for this example. - if !self.uncommitted_diff.is_empty() { - let mut apply_process = smol::process::Command::new("git") - .current_dir(&worktree_path) - .args(&["apply", "-"]) - .stdin(std::process::Stdio::piped()) - .spawn()?; - - let mut stdin = apply_process.stdin.take().unwrap(); - stdin.write_all(self.uncommitted_diff.as_bytes()).await?; - stdin.close().await?; - drop(stdin); - - let apply_result = apply_process.output().await?; - if !apply_result.status.success() { - anyhow::bail!( - "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", - apply_result.status, - String::from_utf8_lossy(&apply_result.stderr), - String::from_utf8_lossy(&apply_result.stdout), - ); + "jsonl".to_string() + }; + + match ext.as_ref() { + "json" => { + let mut example = + serde_json::from_str::(&content).unwrap_or_else(|error| { + panic!("Failed to parse example file: {}\n{error}", path.display()) + }); + if example.name.is_empty() { + example.name = filename; + } + examples.push(example); + } + "jsonl" => examples.extend( + content + .lines() + .enumerate() + .map(|(line_ix, line)| { + let mut example = + serde_json::from_str::(line).unwrap_or_else(|_| { + panic!( + "Failed to parse example on {}:{}", + path.display(), + line_ix + 1 + ) + }); + if example.name.is_empty() { + example.name = format!("{filename}-{line_ix}") + } + example + }) + .collect::>(), + ), + "md" => { + examples.push(parse_markdown_example(filename, &content).unwrap()); + } + ext => { + panic!("{} has invalid example extension `{ext}`", path.display()) } } - - Ok(worktree_path) - } - - pub fn unique_name(&self) -> String { - let mut hasher = std::hash::DefaultHasher::new(); - self.hash(&mut hasher); - let disambiguator = hasher.finish(); - let hash = format!("{:04x}", disambiguator); - format!("{}_{}", &self.revision[..8], &hash[..4]) } + examples } -pub type ActualExcerpt = Excerpt; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Excerpt { - pub path: PathBuf, - pub text: String, -} - -#[derive(ValueEnum, Debug, Clone)] -pub enum ExampleFormat { - Json, - Toml, - Md, +pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) { + let mut content = String::new(); + for example in examples { + let line = serde_json::to_string(example).unwrap(); + content.push_str(&line); + content.push('\n'); + } + if let Some(output_path) = output_path { + std::fs::write(output_path, content).expect("Failed to write examples"); + } else { + std::io::stdout().write_all(&content.as_bytes()).unwrap(); + } } -impl NamedExample { - pub fn load(path: impl AsRef) -> Result { - let path = path.as_ref(); - let content = std::fs::read_to_string(path)?; - let ext = path.extension(); - - match ext.and_then(|s| s.to_str()) { - Some("json") => Ok(Self { - name: path.file_stem().unwrap_or_default().display().to_string(), - example: serde_json::from_str(&content)?, - }), - Some("toml") => Ok(Self { - name: path.file_stem().unwrap_or_default().display().to_string(), - example: toml::from_str(&content)?, - }), - Some("md") => Self::parse_md(&content), - Some(_) => { - anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display()); - } - None => { - anyhow::bail!( - "Failed to determine example type since the file does not have an extension." - ); - } - } +fn parse_markdown_example(id: String, input: &str) -> Result { + use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd}; + + const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff"; + const EDIT_HISTORY_HEADING: &str = "Edit History"; + const CURSOR_POSITION_HEADING: &str = "Cursor Position"; + const EXPECTED_PATCH_HEADING: &str = "Expected Patch"; + const EXPECTED_CONTEXT_HEADING: &str = "Expected Context"; + const REPOSITORY_URL_FIELD: &str = "repository_url"; + const REVISION_FIELD: &str = "revision"; + + let parser = Parser::new(input); + + let mut example = Example { + name: id, + repository_url: String::new(), + revision: String::new(), + uncommitted_diff: String::new(), + cursor_path: PathBuf::new().into(), + cursor_position: String::new(), + edit_history: String::new(), + expected_patch: String::new(), + buffer: None, + context: None, + prompt: None, + predictions: Vec::new(), + score: Vec::new(), + state: None, + }; + + let mut name = String::new(); + let mut text = String::new(); + let mut block_info: CowStr = "".into(); + + #[derive(PartialEq)] + enum Section { + UncommittedDiff, + EditHistory, + CursorPosition, + ExpectedExcerpts, + ExpectedPatch, + Other, } - pub fn parse_md(input: &str) -> Result { - use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd}; - - let parser = Parser::new(input); - - let mut named = NamedExample { - name: String::new(), - example: Example { - repository_url: String::new(), - revision: String::new(), - uncommitted_diff: String::new(), - cursor_path: PathBuf::new(), - cursor_position: String::new(), - edit_history: String::new(), - expected_patch: String::new(), - }, - }; + let mut current_section = Section::Other; - let mut text = String::new(); - let mut block_info: CowStr = "".into(); - - #[derive(PartialEq)] - enum Section { - UncommittedDiff, - EditHistory, - CursorPosition, - ExpectedExcerpts, - ExpectedPatch, - Other, - } + for event in parser { + match event { + Event::Text(line) => { + text.push_str(&line); - let mut current_section = Section::Other; - - for event in parser { - match event { - Event::Text(line) => { - text.push_str(&line); - - if !named.name.is_empty() - && current_section == Section::Other - // in h1 section - && let Some((field, value)) = line.split_once('=') - { - match field.trim() { - REPOSITORY_URL_FIELD => { - named.example.repository_url = value.trim().to_string(); - } - REVISION_FIELD => { - named.example.revision = value.trim().to_string(); - } - _ => {} - } - } - } - Event::End(TagEnd::Heading(HeadingLevel::H1)) => { - if !named.name.is_empty() { - anyhow::bail!( - "Found multiple H1 headings. There should only be one with the name of the example." - ); - } - named.name = mem::take(&mut text); - } - Event::End(TagEnd::Heading(HeadingLevel::H2)) => { - let title = mem::take(&mut text); - current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { - Section::UncommittedDiff - } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { - Section::EditHistory - } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { - Section::CursorPosition - } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { - Section::ExpectedPatch - } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) { - Section::ExpectedExcerpts - } else { - Section::Other - }; - } - Event::End(TagEnd::Heading(HeadingLevel::H3)) => { - mem::take(&mut text); - } - Event::End(TagEnd::Heading(HeadingLevel::H4)) => { - mem::take(&mut text); - } - Event::End(TagEnd::Heading(level)) => { - anyhow::bail!("Unexpected heading level: {level}"); - } - Event::Start(Tag::CodeBlock(kind)) => { - match kind { - CodeBlockKind::Fenced(info) => { - block_info = info; - } - CodeBlockKind::Indented => { - anyhow::bail!("Unexpected indented codeblock"); - } - }; - } - Event::Start(_) => { - text.clear(); - block_info = "".into(); - } - Event::End(TagEnd::CodeBlock) => { - let block_info = block_info.trim(); - match current_section { - Section::UncommittedDiff => { - named.example.uncommitted_diff = mem::take(&mut text); - } - Section::EditHistory => { - named.example.edit_history.push_str(&mem::take(&mut text)); - } - Section::CursorPosition => { - named.example.cursor_path = block_info.into(); - named.example.cursor_position = mem::take(&mut text); - } - Section::ExpectedExcerpts => { - mem::take(&mut text); + if let Some((field, value)) = line.split_once('=') { + match field.trim() { + REPOSITORY_URL_FIELD => { + example.repository_url = value.trim().to_string(); } - Section::ExpectedPatch => { - named.example.expected_patch = mem::take(&mut text); + REVISION_FIELD => { + example.revision = value.trim().to_string(); } - Section::Other => {} + _ => {} } } - _ => {} } - } - - if named.example.cursor_path.as_path() == Path::new("") - || named.example.cursor_position.is_empty() - { - anyhow::bail!("Missing cursor position codeblock"); - } - - Ok(named) - } - - pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> { - match format { - ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?), - ExampleFormat::Toml => { - Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?) + Event::End(TagEnd::Heading(HeadingLevel::H1)) => { + if !name.is_empty() { + anyhow::bail!( + "Found multiple H1 headings. There should only be one with the name of the example." + ); + } + name = mem::take(&mut text); } - ExampleFormat::Md => Ok(write!(out, "{}", self)?), - } - } - - pub async fn setup_project( - &self, - app_state: &Arc, - cx: &mut AsyncApp, - ) -> Result> { - let worktree_path = self.setup_worktree().await?; - - static AUTHENTICATED: OnceLock>> = OnceLock::new(); - - AUTHENTICATED - .get_or_init(|| { - let client = app_state.client.clone(); - cx.spawn(async move |cx| { - client - .sign_in_with_optional_connect(true, cx) - .await - .unwrap(); - }) - .shared() - }) - .clone() - .await; - - let project = cx.update(|cx| { - Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ) - })?; - - let worktree = project - .update(cx, |project, cx| { - project.create_worktree(&worktree_path, true, cx) - })? - .await?; - worktree - .read_with(cx, |worktree, _cx| { - worktree.as_local().unwrap().scan_complete() - })? - .await; - - anyhow::Ok(project) - } - - pub async fn setup_worktree(&self) -> Result { - self.example.setup_worktree(self.file_name()).await - } - - pub fn file_name(&self) -> String { - self.name - .chars() - .map(|c| { - if c.is_whitespace() { - '-' + Event::End(TagEnd::Heading(HeadingLevel::H2)) => { + let title = mem::take(&mut text); + current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) { + Section::UncommittedDiff + } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) { + Section::EditHistory + } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) { + Section::CursorPosition + } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) { + Section::ExpectedPatch + } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) { + Section::ExpectedExcerpts } else { - c.to_ascii_lowercase() + Section::Other + }; + } + Event::End(TagEnd::Heading(HeadingLevel::H3)) => { + mem::take(&mut text); + } + Event::End(TagEnd::Heading(HeadingLevel::H4)) => { + mem::take(&mut text); + } + Event::End(TagEnd::Heading(level)) => { + anyhow::bail!("Unexpected heading level: {level}"); + } + Event::Start(Tag::CodeBlock(kind)) => { + match kind { + CodeBlockKind::Fenced(info) => { + block_info = info; + } + CodeBlockKind::Indented => { + anyhow::bail!("Unexpected indented codeblock"); + } + }; + } + Event::Start(_) => { + text.clear(); + block_info = "".into(); + } + Event::End(TagEnd::CodeBlock) => { + let block_info = block_info.trim(); + match current_section { + Section::UncommittedDiff => { + example.uncommitted_diff = mem::take(&mut text); + } + Section::EditHistory => { + example.edit_history.push_str(&mem::take(&mut text)); + } + Section::CursorPosition => { + example.cursor_path = Path::new(block_info).into(); + example.cursor_position = mem::take(&mut text); + } + Section::ExpectedExcerpts => { + mem::take(&mut text); + } + Section::ExpectedPatch => { + example.expected_patch = mem::take(&mut text); + } + Section::Other => {} } - }) - .collect() - } - - pub async fn cursor_position( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result<(Entity, Anchor)> { - let worktree = project.read_with(cx, |project, cx| { - project.visible_worktrees(cx).next().unwrap() - })?; - let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc(); - let cursor_buffer = project - .update(cx, |project, cx| { - project.open_buffer( - ProjectPath { - worktree_id: worktree.read(cx).id(), - path: cursor_path, - }, - cx, - ) - })? - .await?; - let cursor_offset_within_excerpt = self - .example - .cursor_position - .find(CURSOR_MARKER) - .ok_or_else(|| anyhow!("missing cursor marker"))?; - let mut cursor_excerpt = self.example.cursor_position.clone(); - cursor_excerpt.replace_range( - cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), - "", - ); - let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { - let text = buffer.text(); - - let mut matches = text.match_indices(&cursor_excerpt); - let Some((excerpt_offset, _)) = matches.next() else { - anyhow::bail!( - "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer." - ); - }; - assert!(matches.next().is_none()); - - Ok(excerpt_offset) - })??; - - let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; - let cursor_anchor = - cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?; - Ok((cursor_buffer, cursor_anchor)) - } - - #[must_use] - pub async fn apply_edit_history( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> Result> { - edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await - } -} - -async fn run_git(repo_path: &Path, args: &[&str]) -> Result { - let output = smol::process::Command::new("git") - .current_dir(repo_path) - .args(args) - .output() - .await?; - - anyhow::ensure!( - output.status.success(), - "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", - args.join(" "), - repo_path.display(), - output.status, - String::from_utf8_lossy(&output.stderr), - String::from_utf8_lossy(&output.stdout), - ); - Ok(String::from_utf8(output.stdout)?.trim().to_string()) -} - -impl Display for NamedExample { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "# {}\n\n", self.name)?; - write!( - f, - "{REPOSITORY_URL_FIELD} = {}\n", - self.example.repository_url - )?; - write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?; - - write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?; - write!(f, "`````diff\n")?; - write!(f, "{}", self.example.uncommitted_diff)?; - write!(f, "`````\n")?; - - if !self.example.edit_history.is_empty() { - write!(f, "`````diff\n{}`````\n", self.example.edit_history)?; - } - - write!( - f, - "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n", - self.example.cursor_path.display(), - self.example.cursor_position - )?; - write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?; - - if !self.example.expected_patch.is_empty() { - write!( - f, - "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n", - self.example.expected_patch - )?; + } + _ => {} } - - Ok(()) } -} - -thread_local! { - static REPO_LOCKS: RefCell>>> = RefCell::new(HashMap::default()); -} + if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() { + anyhow::bail!("Missing cursor position codeblock"); + } -#[must_use] -pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { - REPO_LOCKS - .with(|cell| { - cell.borrow_mut() - .entry(path.as_ref().to_path_buf()) - .or_default() - .clone() - }) - .lock_owned() - .await + Ok(example) } diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs new file mode 100644 index 0000000000000000000000000000000000000000..53ef6ebfde77dcecba9926062cdfd75c1ee3521c --- /dev/null +++ b/crates/edit_prediction_cli/src/format_prompt.rs @@ -0,0 +1,280 @@ +use crate::{ + PromptFormat, + example::{Example, ExamplePrompt}, + headless::EpAppState, + retrieve_context::run_context_retrieval, +}; +use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input}; +use gpui::AsyncApp; +use std::sync::Arc; +use zeta_prompt::format_zeta_prompt; + +pub async fn run_format_prompt( + example: &mut Example, + prompt_format: PromptFormat, + app_state: Arc, + mut cx: AsyncApp, +) { + run_context_retrieval(example, app_state, cx.clone()).await; + + let prompt = match prompt_format { + PromptFormat::Teacher => TeacherPrompt::format(example), + PromptFormat::Zeta2 => { + let ep_store = cx + .update(|cx| EditPredictionStore::try_global(cx).unwrap()) + .unwrap(); + + let state = example.state.as_ref().unwrap(); + let snapshot = state + .buffer + .read_with(&cx, |buffer, _| buffer.snapshot()) + .unwrap(); + let project = state.project.clone(); + let (_, input) = ep_store + .update(&mut cx, |ep_store, _cx| { + zeta2_prompt_input( + &snapshot, + example.context.as_ref().unwrap().files.clone(), + ep_store.edit_history_for_project(&project), + example.cursor_path.clone(), + example.buffer.as_ref().unwrap().cursor_offset, + ) + }) + .unwrap(); + format_zeta_prompt(&input) + } + }; + + example.prompt = Some(ExamplePrompt { + input: prompt, + expected_output: example.expected_patch.clone(), // TODO + format: prompt_format, + }); +} + +pub trait PromptFormatter { + fn format(example: &Example) -> String; +} + +pub trait PromptParser { + /// Return unified diff patch of prediction given raw LLM response + fn parse(example: &Example, response: &str) -> String; +} + +pub struct TeacherPrompt; + +impl PromptFormatter for TeacherPrompt { + fn format(example: &Example) -> String { + let edit_history = Self::format_edit_history(&example.edit_history); + let context = Self::format_context(example); + let editable_region = Self::format_editable_region(example); + + let prompt = Self::PROMPT + .replace("{{context}}", &context) + .replace("{{edit_history}}", &edit_history) + .replace("{{editable_region}}", &editable_region); + + prompt + } +} + +impl TeacherPrompt { + const PROMPT: &str = include_str!("teacher.prompt.md"); + pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n"; + pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>"; + + /// Truncate edit history to this number of last lines + const MAX_HISTORY_LINES: usize = 128; + + fn format_edit_history(edit_history: &str) -> String { + // Strip comments ("garbage lines") from edit history + let lines = edit_history + .lines() + .filter(|&s| Self::is_udiff_content_line(s)) + .collect::>(); + + let history_lines = if lines.len() > Self::MAX_HISTORY_LINES { + &lines[lines.len() - Self::MAX_HISTORY_LINES..] + } else { + &lines + }; + + if history_lines.is_empty() { + return "(No edit history)".to_string(); + } + + history_lines.join("\n") + } + + fn format_context(example: &Example) -> String { + if example.context.is_none() { + panic!("Missing context retriever step"); + } + + let mut prompt = String::new(); + zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files); + + prompt + } + + fn format_editable_region(example: &Example) -> String { + let mut result = String::new(); + + let path_str = example.cursor_path.to_string_lossy(); + result.push_str(&format!("`````path=\"{path_str}\"\n")); + result.push_str(Self::EDITABLE_REGION_START); + + // TODO: control number of lines around cursor + result.push_str(&example.cursor_position); + if !example.cursor_position.ends_with('\n') { + result.push('\n'); + } + + result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END)); + result.push_str("`````"); + + result + } + + fn extract_editable_region(text: &str) -> String { + let start = text + .find(Self::EDITABLE_REGION_START) + .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len()); + let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len()); + + let region = &text[start..end]; + + region.replace("<|user_cursor|>", "") + } + + fn is_udiff_content_line(s: &str) -> bool { + s.starts_with("-") + || s.starts_with("+") + || s.starts_with(" ") + || s.starts_with("---") + || s.starts_with("+++") + || s.starts_with("@@") + } +} + +impl PromptParser for TeacherPrompt { + fn parse(example: &Example, response: &str) -> String { + // Ideally, we should always be able to find cursor position in the retrieved context. + // In reality, sometimes we don't find it for these reasons: + // 1. `example.cursor_position` contains _more_ context than included in the retrieved context + // (can be fixed by getting cursor coordinates at the load_example stage) + // 2. Context retriever just didn't include cursor line. + // + // In that case, fallback to using `cursor_position` as excerpt. + let cursor_file = &example + .buffer + .as_ref() + .expect("`buffer` should be filled in in the context collection step") + .content; + + // Extract updated (new) editable region from the model response + let new_editable_region = extract_last_codeblock(response); + + // Reconstruct old editable region we sent to the model + let old_editable_region = Self::format_editable_region(example); + let old_editable_region = Self::extract_editable_region(&old_editable_region); + if !cursor_file.contains(&old_editable_region) { + panic!("Something's wrong: editable_region is not found in the cursor file") + } + + // Apply editable region to a larger context and compute diff. + // This is needed to get a better context lines around the editable region + let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region); + let diff = language::unified_diff(&cursor_file, &edited_file); + + let diff = indoc::formatdoc! {" + --- a/{path} + +++ b/{path} + {diff} + ", + path = example.cursor_path.to_string_lossy(), + diff = diff, + }; + + diff + } +} + +fn extract_last_codeblock(text: &str) -> String { + let mut last_block = None; + let mut search_start = 0; + + while let Some(start) = text[search_start..].find("```") { + let start = start + search_start; + let bytes = text.as_bytes(); + let mut backtick_end = start; + + while backtick_end < bytes.len() && bytes[backtick_end] == b'`' { + backtick_end += 1; + } + + let backtick_count = backtick_end - start; + let closing_backticks = "`".repeat(backtick_count); + + while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' { + backtick_end += 1; + } + + if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) { + let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1]; + last_block = Some(code_block.to_string()); + search_start = backtick_end + end_pos + backtick_count; + } else { + break; + } + } + + last_block.unwrap_or_else(|| text.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_last_code_block() { + let text = indoc::indoc! {" + Some thinking + + ``` + first block + ``` + + `````path='something' lines=1:2 + last block + ````` + "}; + let last_block = extract_last_codeblock(text); + assert_eq!(last_block, "last block"); + } + + #[test] + fn test_extract_editable_region() { + let text = indoc::indoc! {" + some lines + are + here + <|editable_region_start|> + one + two three + + <|editable_region_end|> + more + lines here + "}; + let parsed = TeacherPrompt::extract_editable_region(text); + assert_eq!( + parsed, + indoc::indoc! {" + one + two three + + "} + ); + } +} diff --git a/crates/edit_prediction_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs index c4d8667d63dfb3dd39fbced609e0ae0bc44974d2..fd20774168ea3c07f4efffdefe23f1b4ff5f5ef4 100644 --- a/crates/edit_prediction_cli/src/headless.rs +++ b/crates/edit_prediction_cli/src/headless.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use util::ResultExt as _; /// Headless subset of `workspace::AppState`. -pub struct ZetaCliAppState { +pub struct EpAppState { pub languages: Arc, pub client: Arc, pub user_store: Entity, @@ -25,7 +25,7 @@ pub struct ZetaCliAppState { } // TODO: dedupe with crates/eval/src/eval.rs -pub fn init(cx: &mut App) -> ZetaCliAppState { +pub fn init(cx: &mut App) -> EpAppState { let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned())); let app_version = AppVersion::load( @@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState { prompt_store::init(cx); terminal_view::init(cx); - ZetaCliAppState { + EpAppState { languages, client, user_store, diff --git a/crates/edit_prediction_cli/src/load_project.rs b/crates/edit_prediction_cli/src/load_project.rs new file mode 100644 index 0000000000000000000000000000000000000000..842b63a43335454655ed41ef4d852167e8faf72a --- /dev/null +++ b/crates/edit_prediction_cli/src/load_project.rs @@ -0,0 +1,320 @@ +use crate::{ + example::{Example, ExampleBuffer, ExampleState}, + headless::EpAppState, +}; +use anyhow::{Result, anyhow}; +use collections::HashMap; +use edit_prediction::EditPredictionStore; +use edit_prediction::udiff::OpenedBuffers; +use futures::{ + AsyncWriteExt as _, + lock::{Mutex, OwnedMutexGuard}, +}; +use gpui::{AsyncApp, Entity}; +use language::{Anchor, Buffer, ToOffset, ToPoint}; +use project::buffer_store::BufferStoreEvent; +use project::{Project, ProjectPath}; +use std::{ + cell::RefCell, + fs, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::{paths::PathStyle, rel_path::RelPath}; +use zeta_prompt::CURSOR_MARKER; + +pub async fn run_load_project(example: &mut Example, app_state: Arc, mut cx: AsyncApp) { + if example.state.is_some() { + return; + } + + let project = setup_project(example, &app_state, &mut cx).await; + let buffer_store = project + .read_with(&cx, |project, _| project.buffer_store().clone()) + .unwrap(); + + let ep_store = cx + .update(|cx| EditPredictionStore::try_global(cx).unwrap()) + .unwrap(); + + cx.subscribe(&buffer_store, { + let project = project.clone(); + move |_, event, cx| match event { + BufferStoreEvent::BufferAdded(buffer) => { + ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx)); + } + _ => {} + } + }) + .unwrap() + .detach(); + + let _open_buffers = apply_edit_history(example, &project, &mut cx) + .await + .unwrap(); + let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await; + example.buffer = buffer + .read_with(&cx, |buffer, _cx| { + let cursor_point = cursor_position.to_point(&buffer); + Some(ExampleBuffer { + content: buffer.text(), + cursor_row: cursor_point.row, + cursor_column: cursor_point.column, + cursor_offset: cursor_position.to_offset(&buffer), + }) + }) + .unwrap(); + example.state = Some(ExampleState { + buffer, + project, + cursor_position, + _open_buffers, + }); +} + +async fn cursor_position( + example: &Example, + project: &Entity, + cx: &mut AsyncApp, +) -> (Entity, Anchor) { + let worktree = project + .read_with(cx, |project, cx| { + project.visible_worktrees(cx).next().unwrap() + }) + .unwrap(); + + let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix) + .unwrap() + .into_arc(); + let cursor_buffer = project + .update(cx, |project, cx| { + project.open_buffer( + ProjectPath { + worktree_id: worktree.read(cx).id(), + path: cursor_path, + }, + cx, + ) + }) + .unwrap() + .await + .unwrap(); + let cursor_offset_within_excerpt = example + .cursor_position + .find(CURSOR_MARKER) + .ok_or_else(|| anyhow!("missing cursor marker")) + .unwrap(); + let mut cursor_excerpt = example.cursor_position.clone(); + cursor_excerpt.replace_range( + cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()), + "", + ); + let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| { + let text = buffer.text(); + + let mut matches = text.match_indices(&cursor_excerpt); + let (excerpt_offset, _) = matches.next().unwrap_or_else(|| { + panic!( + "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer." + ); + }); + assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name); + excerpt_offset + }).unwrap(); + + let cursor_offset = excerpt_offset + cursor_offset_within_excerpt; + let cursor_anchor = cursor_buffer + .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset)) + .unwrap(); + + (cursor_buffer, cursor_anchor) +} + +async fn setup_project( + example: &mut Example, + app_state: &Arc, + cx: &mut AsyncApp, +) -> Entity { + setup_worktree(example).await; + + let project = cx + .update(|cx| { + Project::local( + app_state.client.clone(), + app_state.node_runtime.clone(), + app_state.user_store.clone(), + app_state.languages.clone(), + app_state.fs.clone(), + None, + cx, + ) + }) + .unwrap(); + + let worktree = project + .update(cx, |project, cx| { + project.create_worktree(&example.worktree_path(), true, cx) + }) + .unwrap() + .await + .unwrap(); + worktree + .read_with(cx, |worktree, _cx| { + worktree.as_local().unwrap().scan_complete() + }) + .unwrap() + .await; + project +} + +pub async fn setup_worktree(example: &Example) { + let repo_dir = example.repo_path(); + let repo_lock = lock_repo(&repo_dir).await; + + if !repo_dir.is_dir() { + fs::create_dir_all(&repo_dir).unwrap(); + run_git(&repo_dir, &["init"]).await.unwrap(); + run_git( + &repo_dir, + &["remote", "add", "origin", &example.repository_url], + ) + .await + .unwrap(); + } + + // Resolve the example to a revision, fetching it if needed. + let revision = run_git( + &repo_dir, + &["rev-parse", &format!("{}^{{commit}}", example.revision)], + ) + .await; + let revision = if let Ok(revision) = revision { + revision + } else { + if run_git( + &repo_dir, + &["fetch", "--depth", "1", "origin", &example.revision], + ) + .await + .is_err() + { + run_git(&repo_dir, &["fetch", "origin"]).await.unwrap(); + } + let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]) + .await + .unwrap(); + if revision != example.revision { + run_git(&repo_dir, &["tag", &example.revision, &revision]) + .await + .unwrap(); + } + revision + }; + + // Create the worktree for this example if needed. + let worktree_path = example.worktree_path(); + if worktree_path.is_dir() { + run_git(&worktree_path, &["clean", "--force", "-d"]) + .await + .unwrap(); + run_git(&worktree_path, &["reset", "--hard", "HEAD"]) + .await + .unwrap(); + run_git(&worktree_path, &["checkout", revision.as_str()]) + .await + .unwrap(); + } else { + let worktree_path_string = worktree_path.to_string_lossy(); + run_git( + &repo_dir, + &["branch", "-f", &example.name, revision.as_str()], + ) + .await + .unwrap(); + run_git( + &repo_dir, + &[ + "worktree", + "add", + "-f", + &worktree_path_string, + &example.name, + ], + ) + .await + .unwrap(); + } + drop(repo_lock); + + // Apply the uncommitted diff for this example. + if !example.uncommitted_diff.is_empty() { + let mut apply_process = smol::process::Command::new("git") + .current_dir(&worktree_path) + .args(&["apply", "-"]) + .stdin(std::process::Stdio::piped()) + .spawn() + .unwrap(); + + let mut stdin = apply_process.stdin.take().unwrap(); + stdin + .write_all(example.uncommitted_diff.as_bytes()) + .await + .unwrap(); + stdin.close().await.unwrap(); + drop(stdin); + + let apply_result = apply_process.output().await.unwrap(); + if !apply_result.status.success() { + panic!( + "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}", + apply_result.status, + String::from_utf8_lossy(&apply_result.stderr), + String::from_utf8_lossy(&apply_result.stdout), + ); + } + } +} + +async fn apply_edit_history( + example: &Example, + project: &Entity, + cx: &mut AsyncApp, +) -> Result { + edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await +} + +thread_local! { + static REPO_LOCKS: RefCell>>> = RefCell::new(HashMap::default()); +} + +#[must_use] +pub async fn lock_repo(path: impl AsRef) -> OwnedMutexGuard<()> { + REPO_LOCKS + .with(|cell| { + cell.borrow_mut() + .entry(path.as_ref().to_path_buf()) + .or_default() + .clone() + }) + .lock_owned() + .await +} + +async fn run_git(repo_path: &Path, args: &[&str]) -> Result { + let output = smol::process::Command::new("git") + .current_dir(repo_path) + .args(args) + .output() + .await?; + + anyhow::ensure!( + output.status.success(), + "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}", + args.join(" "), + repo_path.display(), + output.status, + String::from_utf8_lossy(&output.stderr), + String::from_utf8_lossy(&output.stdout), + ); + Ok(String::from_utf8(output.stdout)?.trim().to_string()) +} diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 00086777f1f03112b92f11923ad2d025276699f5..51ea23649d0ec0b124c38ead2897ba16ecd96e26 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -1,522 +1,196 @@ -mod evaluate; +mod anthropic_client; mod example; +mod format_prompt; mod headless; +mod load_project; mod metrics; mod paths; mod predict; -mod source_location; -mod training; -mod util; +mod retrieve_context; +mod score; -use crate::{ - evaluate::run_evaluate, - example::{ExampleFormat, NamedExample}, - headless::ZetaCliAppState, - predict::run_predict, - source_location::SourceLocation, - training::{context::ContextType, distill::run_distill}, - util::{open_buffer, open_buffer_with_language_server}, -}; -use ::util::{ResultExt, paths::PathStyle}; -use anyhow::{Result, anyhow}; -use clap::{Args, Parser, Subcommand, ValueEnum}; -use cloud_llm_client::predict_edits_v3; -use edit_prediction::udiff::DiffLine; -use edit_prediction_context::EditPredictionExcerptOptions; -use gpui::{Application, AsyncApp, Entity, prelude::*}; -use language::{Bias, Buffer, BufferSnapshot, Point}; -use metrics::delta_chr_f; -use project::{Project, Worktree, lsp_store::OpenLspBufferHandle}; +use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; +use edit_prediction::EditPredictionStore; +use gpui::Application; use reqwest_client::ReqwestClient; -use std::io::{self}; -use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc}; +use serde::{Deserialize, Serialize}; +use std::{path::PathBuf, sync::Arc}; + +use crate::example::{read_examples, write_examples}; +use crate::format_prompt::run_format_prompt; +use crate::load_project::run_load_project; +use crate::predict::run_prediction; +use crate::retrieve_context::run_context_retrieval; +use crate::score::run_scoring; #[derive(Parser, Debug)] -#[command(name = "zeta")] -struct ZetaCliArgs { +#[command(name = "ep")] +struct EpArgs { #[arg(long, default_value_t = false)] printenv: bool, + #[clap(long, default_value_t = 10)] + max_parallelism: usize, #[command(subcommand)] command: Option, + #[clap(global = true)] + inputs: Vec, + #[arg(long, short, global = true)] + output: Option, + #[arg(long, short, global = true)] + in_place: bool, } #[derive(Subcommand, Debug)] enum Command { - Context(ContextArgs), - Predict(PredictArguments), - Eval(EvaluateArguments), - Distill(DistillArguments), - ConvertExample { - path: PathBuf, - #[arg(long, value_enum, default_value_t = ExampleFormat::Md)] - output_format: ExampleFormat, - }, - Score { - golden_patch: PathBuf, - actual_patch: PathBuf, - }, + /// Parse markdown examples and output a combined .jsonl file + ParseExample, + /// Create git worktrees for each example and load file contents + LoadBuffer, + /// Retrieve context for input examples. + Context, + /// Generate a prompt string for a specific model + FormatPrompt(FormatPromptArgs), + /// Runs edit prediction + Predict(PredictArgs), + /// Computes a score based on actual and expected patches + Score(PredictArgs), + /// Print aggregated scores + Eval(PredictArgs), + /// Remove git repositories and worktrees Clean, } #[derive(Debug, Args)] -struct ContextArgs { - #[arg(long)] - provider: ContextProvider, - #[arg(long)] - worktree: PathBuf, - #[arg(long)] - cursor: SourceLocation, - #[arg(long)] - use_language_server: bool, - #[arg(long)] - edit_history: Option, - #[clap(flatten)] - zeta2_args: Zeta2Args, -} - -#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)] -enum ContextProvider { - Zeta1, - #[default] - Zeta2, -} - -#[derive(Clone, Debug, Args)] -struct Zeta2Args { - #[arg(long, default_value_t = 8192)] - max_prompt_bytes: usize, - #[arg(long, default_value_t = 2048)] - max_excerpt_bytes: usize, - #[arg(long, default_value_t = 1024)] - min_excerpt_bytes: usize, - #[arg(long, default_value_t = 0.66)] - target_before_cursor_over_total_bytes: f32, - #[arg(long, default_value_t = 1024)] - max_diagnostic_bytes: usize, - #[arg(long, value_enum, default_value_t = PromptFormat::default())] +struct FormatPromptArgs { + #[clap(long)] prompt_format: PromptFormat, - #[arg(long, value_enum, default_value_t = Default::default())] - output_format: OutputFormat, - #[arg(long, default_value_t = 42)] - file_indexing_parallelism: usize, - #[arg(long, default_value_t = false)] - disable_imports_gathering: bool, - #[arg(long, default_value_t = u8::MAX)] - max_retrieved_definitions: u8, } -#[derive(Debug, Args)] -pub struct PredictArguments { - #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)] - format: PredictionsOutputFormat, - example_path: PathBuf, - #[clap(flatten)] - options: PredictionOptions, +#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)] +enum PromptFormat { + Teacher, + Zeta2, } #[derive(Debug, Args)] -pub struct DistillArguments { - split_commit_dataset: PathBuf, - #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)] - context_type: ContextType, - #[clap(long)] - batch: Option, -} - -#[derive(Clone, Debug, Args)] -pub struct PredictionOptions { - #[clap(flatten)] - zeta2: Zeta2Args, +struct PredictArgs { #[clap(long)] provider: PredictionProvider, - #[clap(long, value_enum, default_value_t = CacheMode::default())] - cache: CacheMode, -} - -#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)] -pub enum CacheMode { - /// Use cached LLM requests and responses, except when multiple repetitions are requested - #[default] - Auto, - /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint. - #[value(alias = "request")] - Requests, - /// Ignore existing cache entries for both LLM and search. - Skip, - /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet. - /// Useful for reproducing results and fixing bugs outside of search queries - Force, -} - -impl CacheMode { - fn use_cached_llm_responses(&self) -> bool { - self.assert_not_auto(); - matches!(self, CacheMode::Requests | CacheMode::Force) - } - - fn use_cached_search_results(&self) -> bool { - self.assert_not_auto(); - matches!(self, CacheMode::Force) - } - - fn assert_not_auto(&self) { - assert_ne!( - *self, - CacheMode::Auto, - "Cache mode should not be auto at this point!" - ); - } -} - -#[derive(clap::ValueEnum, Debug, Clone)] -pub enum PredictionsOutputFormat { - Json, - Md, - Diff, + #[clap(long, default_value_t = 1)] + repetitions: usize, } -#[derive(Debug, Args)] -pub struct EvaluateArguments { - example_paths: Vec, - #[clap(flatten)] - options: PredictionOptions, - #[clap(short, long, default_value_t = 1, alias = "repeat")] - repetitions: u16, - #[arg(long)] - skip_prediction: bool, -} - -#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)] +#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)] enum PredictionProvider { + Sweep, + Mercury, Zeta1, - #[default] Zeta2, - Sweep, -} - -fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions { - edit_prediction::ZetaOptions { - context: EditPredictionExcerptOptions { - max_bytes: args.max_excerpt_bytes, - min_bytes: args.min_excerpt_bytes, - target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes, - }, - max_prompt_bytes: args.max_prompt_bytes, - prompt_format: args.prompt_format.into(), - } -} - -#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)] -enum PromptFormat { - OnlySnippets, - #[default] - OldTextNewText, - Minimal, - MinimalQwen, - SeedCoder1120, + Teacher, } -impl Into for PromptFormat { - fn into(self) -> predict_edits_v3::PromptFormat { - match self { - Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets, - Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText, - Self::Minimal => predict_edits_v3::PromptFormat::Minimal, - Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen, - Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120, +impl EpArgs { + fn output_path(&self) -> Option { + if self.in_place { + if self.inputs.len() == 1 { + self.inputs.first().cloned() + } else { + panic!("--in-place requires exactly one input file") + } + } else { + self.output.clone() } } } -#[derive(clap::ValueEnum, Default, Debug, Clone)] -enum OutputFormat { - #[default] - Prompt, - Request, - Full, -} - -#[derive(Debug, Clone)] -enum FileOrStdin { - File(PathBuf), - Stdin, -} +fn main() { + zlog::init(); + zlog::init_output_stderr(); + let args = EpArgs::parse(); -impl FileOrStdin { - async fn read_to_string(&self) -> Result { - match self { - FileOrStdin::File(path) => smol::fs::read_to_string(path).await, - FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await, - } + if args.printenv { + ::util::shell_env::print_env(); + return; } -} - -impl FromStr for FileOrStdin { - type Err = ::Err; - fn from_str(s: &str) -> Result { - match s { - "-" => Ok(Self::Stdin), - _ => Ok(Self::File(PathBuf::from_str(s)?)), + let output = args.output_path(); + let command = match args.command { + Some(cmd) => cmd, + None => { + EpArgs::command().print_help().unwrap(); + return; } - } -} - -struct LoadedContext { - full_path_str: String, - snapshot: BufferSnapshot, - clipped_cursor: Point, - worktree: Entity, - project: Entity, - buffer: Entity, - lsp_open_handle: Option, -} - -async fn load_context( - args: &ContextArgs, - app_state: &Arc, - cx: &mut AsyncApp, -) -> Result { - let ContextArgs { - worktree: worktree_path, - cursor, - use_language_server, - .. - } = args; - - let worktree_path = worktree_path.canonicalize()?; - - let project = cx.update(|cx| { - Project::local( - app_state.client.clone(), - app_state.node_runtime.clone(), - app_state.user_store.clone(), - app_state.languages.clone(), - app_state.fs.clone(), - None, - cx, - ) - })?; - - let worktree = project - .update(cx, |project, cx| { - project.create_worktree(&worktree_path, true, cx) - })? - .await?; - - let mut ready_languages = HashSet::default(); - let (lsp_open_handle, buffer) = if *use_language_server { - let (lsp_open_handle, _, buffer) = open_buffer_with_language_server( - project.clone(), - worktree.clone(), - cursor.path.clone(), - &mut ready_languages, - cx, - ) - .await?; - (Some(lsp_open_handle), buffer) - } else { - let buffer = - open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?; - (None, buffer) }; - let full_path_str = worktree - .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))? - .display(PathStyle::local()) - .to_string(); - - let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?; - let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left); - if clipped_cursor != cursor.point { - let max_row = snapshot.max_point().row; - if cursor.point.row < max_row { - return Err(anyhow!( - "Cursor position {:?} is out of bounds (line length is {})", - cursor.point, - snapshot.line_len(cursor.point.row) - )); - } else { - return Err(anyhow!( - "Cursor position {:?} is out of bounds (max row is {})", - cursor.point, - max_row - )); + match &command { + Command::Clean => { + std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap(); + return; } + _ => {} } - Ok(LoadedContext { - full_path_str, - snapshot, - clipped_cursor, - worktree, - project, - buffer, - lsp_open_handle, - }) -} - -async fn zeta2_context( - args: ContextArgs, - app_state: &Arc, - cx: &mut AsyncApp, -) -> Result { - let LoadedContext { - worktree, - project, - buffer, - clipped_cursor, - lsp_open_handle: _handle, - .. - } = load_context(&args, app_state, cx).await?; - - // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for - // the whole worktree. - worktree - .read_with(cx, |worktree, _cx| { - worktree.as_local().unwrap().scan_complete() - })? - .await; - let output = cx - .update(|cx| { - let store = cx.new(|cx| { - edit_prediction::EditPredictionStore::new( - app_state.client.clone(), - app_state.user_store.clone(), - cx, - ) - }); - store.update(cx, |store, cx| { - store.set_options(zeta2_args_to_options(&args.zeta2_args)); - store.register_buffer(&buffer, &project, cx); - }); - cx.spawn(async move |cx| { - let updates_rx = store.update(cx, |store, cx| { - let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor); - store.set_use_context(true); - store.refresh_context(&project, &buffer, cursor, cx); - store.project_context_updates(&project).unwrap() - })?; - - updates_rx.recv().await.ok(); - - let context = store.update(cx, |store, cx| { - store.context_for_project(&project, cx).to_vec() - })?; - - anyhow::Ok(serde_json::to_string_pretty(&context).unwrap()) - }) - })? - .await?; - - Ok(output) -} - -async fn zeta1_context( - args: ContextArgs, - app_state: &Arc, - cx: &mut AsyncApp, -) -> Result { - let LoadedContext { - full_path_str, - snapshot, - clipped_cursor, - .. - } = load_context(&args, app_state, cx).await?; - - let events = match args.edit_history { - Some(events) => events.read_to_string().await?, - None => String::new(), - }; - - let prompt_for_events = move || (events, 0); - cx.update(|cx| { - edit_prediction::zeta1::gather_context( - full_path_str, - &snapshot, - clipped_cursor, - prompt_for_events, - cloud_llm_client::PredictEditsRequestTrigger::Cli, - cx, - ) - })? - .await -} - -fn main() { - zlog::init(); - zlog::init_output_stderr(); - let args = ZetaCliArgs::parse(); + let mut examples = read_examples(&args.inputs); let http_client = Arc::new(ReqwestClient::new()); let app = Application::headless().with_http_client(http_client); app.run(move |cx| { let app_state = Arc::new(headless::init(cx)); + EditPredictionStore::global(&app_state.client, &app_state.user_store, cx); + cx.spawn(async move |cx| { - match args.command { - None => { - if args.printenv { - ::util::shell_env::print_env(); - } else { - panic!("Expected a command"); - } - } - Some(Command::Context(context_args)) => { - let result = match context_args.provider { - ContextProvider::Zeta1 => { - let context = - zeta1_context(context_args, &app_state, cx).await.unwrap(); - serde_json::to_string_pretty(&context.body).unwrap() - } - ContextProvider::Zeta2 => { - zeta2_context(context_args, &app_state, cx).await.unwrap() + match &command { + Command::Predict(args) => predict::sync_batches(&args.provider).await, + _ => (), + }; + + for data in examples.chunks_mut(args.max_parallelism) { + let mut futures = Vec::new(); + for example in data.iter_mut() { + let cx = cx.clone(); + let app_state = app_state.clone(); + futures.push(async { + match &command { + Command::ParseExample => {} + Command::LoadBuffer => { + run_load_project(example, app_state.clone(), cx).await; + } + Command::Context => { + run_context_retrieval(example, app_state, cx).await; + } + Command::FormatPrompt(args) => { + run_format_prompt(example, args.prompt_format, app_state, cx).await; + } + Command::Predict(args) => { + run_prediction( + example, + Some(args.provider), + args.repetitions, + app_state.clone(), + cx, + ) + .await; + } + Command::Score(args) | Command::Eval(args) => { + run_scoring(example, &args, app_state, cx).await; + } + Command::Clean => { + unreachable!() + } } - }; - println!("{}", result); - } - Some(Command::Predict(arguments)) => { - run_predict(arguments, &app_state, cx).await; - } - Some(Command::Eval(arguments)) => { - run_evaluate(arguments, &app_state, cx).await; + }); } - Some(Command::Distill(arguments)) => { - let _guard = cx - .update(|cx| gpui_tokio::Tokio::handle(cx)) - .unwrap() - .enter(); - run_distill(arguments).await.log_err(); - } - Some(Command::ConvertExample { - path, - output_format, - }) => { - let example = NamedExample::load(path).unwrap(); - example.write(output_format, io::stdout()).unwrap(); - } - Some(Command::Score { - golden_patch, - actual_patch, - }) => { - let golden_content = std::fs::read_to_string(golden_patch).unwrap(); - let actual_content = std::fs::read_to_string(actual_patch).unwrap(); - - let golden_diff: Vec = golden_content - .lines() - .map(|line| DiffLine::parse(line)) - .collect(); + futures::future::join_all(futures).await; + } - let actual_diff: Vec = actual_content - .lines() - .map(|line| DiffLine::parse(line)) - .collect(); + if args.output.is_some() || !matches!(command, Command::Eval(_)) { + write_examples(&examples, output.as_ref()); + } - let score = delta_chr_f(&golden_diff, &actual_diff); - println!("{:.2}", score); - } - Some(Command::Clean) => { - std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap() - } + match &command { + Command::Predict(args) => predict::sync_batches(&args.provider).await, + Command::Eval(_) => score::print_report(&examples), + _ => (), }; let _ = cx.update(|cx| cx.quit()); diff --git a/crates/edit_prediction_cli/src/metrics.rs b/crates/edit_prediction_cli/src/metrics.rs index 0fdb7fb535df12d00341997a64a96b97867f6f28..b3e5eb8688724c821953a56c4fe82e67c75e13b6 100644 --- a/crates/edit_prediction_cli/src/metrics.rs +++ b/crates/edit_prediction_cli/src/metrics.rs @@ -1,30 +1,34 @@ use collections::{HashMap, HashSet}; use edit_prediction::udiff::DiffLine; +use serde::{Deserialize, Serialize}; type Counts = HashMap; type CountsDelta = HashMap; -#[derive(Default, Debug, Clone)] -pub struct Scores { +#[derive(Default, Debug, Clone, Serialize, Deserialize)] +pub struct ClassificationMetrics { pub true_positives: usize, pub false_positives: usize, pub false_negatives: usize, } -impl Scores { - pub fn from_sets(expected: &HashSet, actual: &HashSet) -> Scores { +impl ClassificationMetrics { + pub fn from_sets( + expected: &HashSet, + actual: &HashSet, + ) -> ClassificationMetrics { let true_positives = expected.intersection(actual).count(); let false_positives = actual.difference(expected).count(); let false_negatives = expected.difference(actual).count(); - Scores { + ClassificationMetrics { true_positives, false_positives, false_negatives, } } - pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores { + pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics { let mut true_positives = 0; let mut false_positives = 0; let mut false_negatives = 0; @@ -45,32 +49,16 @@ impl Scores { } } - Scores { + ClassificationMetrics { true_positives, false_positives, false_negatives, } } - pub fn to_markdown(&self) -> String { - format!( - " -Precision : {:.4} -Recall : {:.4} -F1 Score : {:.4} -True Positives : {} -False Positives : {} -False Negatives : {}", - self.precision(), - self.recall(), - self.f1_score(), - self.true_positives, - self.false_positives, - self.false_negatives - ) - } - - pub fn aggregate<'a>(scores: impl Iterator) -> Scores { + pub fn aggregate<'a>( + scores: impl Iterator, + ) -> ClassificationMetrics { let mut true_positives = 0; let mut false_positives = 0; let mut false_negatives = 0; @@ -81,7 +69,7 @@ False Negatives : {}", false_negatives += score.false_negatives; } - Scores { + ClassificationMetrics { true_positives, false_positives, false_negatives, @@ -115,7 +103,10 @@ False Negatives : {}", } } -pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores { +pub fn line_match_score( + expected_patch: &[DiffLine], + actual_patch: &[DiffLine], +) -> ClassificationMetrics { let expected_change_lines = expected_patch .iter() .filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_))) @@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) .map(|line| line.to_string()) .collect(); - Scores::from_sets(&expected_change_lines, &actual_change_lines) + ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines) } enum ChrfWhitespace { @@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 { let expected_counts = ngram_delta_to_counts(&expected_delta); let actual_counts = ngram_delta_to_counts(&actual_delta); - let score = Scores::from_counts(&expected_counts, &actual_counts); + let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts); total_precision += score.precision(); total_recall += score.recall(); } diff --git a/crates/edit_prediction_cli/src/paths.rs b/crates/edit_prediction_cli/src/paths.rs index 3cc2beec5bd50380b9eef8b502dcba0ccba32772..0f470fae556b6d61739ab77083d7edbedf77ef89 100644 --- a/crates/edit_prediction_cli/src/paths.rs +++ b/crates/edit_prediction_cli/src/paths.rs @@ -1,57 +1,25 @@ -use std::{env, path::PathBuf, sync::LazyLock}; +use std::{ + path::{Path, PathBuf}, + sync::LazyLock, +}; -pub static TARGET_ZETA_DIR: LazyLock = - LazyLock::new(|| env::current_dir().unwrap().join("target/zeta")); -pub static CACHE_DIR: LazyLock = LazyLock::new(|| TARGET_ZETA_DIR.join("cache")); -pub static REPOS_DIR: LazyLock = LazyLock::new(|| TARGET_ZETA_DIR.join("repos")); -pub static WORKTREES_DIR: LazyLock = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees")); +pub static DATA_DIR: LazyLock = LazyLock::new(|| { + let dir = dirs::home_dir().unwrap().join(".zed_ep"); + ensure_dir(&dir) +}); +pub static CACHE_DIR: LazyLock = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache"))); +pub static REPOS_DIR: LazyLock = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos"))); +pub static WORKTREES_DIR: LazyLock = + LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees"))); pub static RUN_DIR: LazyLock = LazyLock::new(|| { - TARGET_ZETA_DIR + DATA_DIR .join("runs") .join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string()) }); -pub static LATEST_EXAMPLE_RUN_DIR: LazyLock = - LazyLock::new(|| TARGET_ZETA_DIR.join("latest")); - -pub fn print_run_data_dir(deep: bool, use_color: bool) { - println!("\n## Run Data\n"); - let mut files = Vec::new(); - - let current_dir = std::env::current_dir().unwrap(); - for file in std::fs::read_dir(&*RUN_DIR).unwrap() { - let file = file.unwrap(); - if file.file_type().unwrap().is_dir() && deep { - for file in std::fs::read_dir(file.path()).unwrap() { - let path = file.unwrap().path(); - let path = path.strip_prefix(¤t_dir).unwrap_or(&path); - files.push(format!( - "- {}/{}{}{}", - path.parent().unwrap().display(), - if use_color { "\x1b[34m" } else { "" }, - path.file_name().unwrap().display(), - if use_color { "\x1b[0m" } else { "" }, - )); - } - } else { - let path = file.path(); - let path = path.strip_prefix(¤t_dir).unwrap_or(&path); - files.push(format!( - "- {}/{}{}{}", - path.parent().unwrap().display(), - if use_color { "\x1b[34m" } else { "" }, - path.file_name().unwrap().display(), - if use_color { "\x1b[0m" } else { "" } - )); - } - } - files.sort(); - - for file in files { - println!("{}", file); - } +pub static LATEST_EXAMPLE_RUN_DIR: LazyLock = LazyLock::new(|| DATA_DIR.join("latest")); +pub static LLM_CACHE_DB: LazyLock = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite")); - println!( - "\n💡 Tip of the day: {} always points to the latest run\n", - LATEST_EXAMPLE_RUN_DIR.display() - ); +fn ensure_dir(path: &Path) -> PathBuf { + std::fs::create_dir_all(path).expect("Failed to create directory"); + path.to_path_buf() } diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 74e939b887ce15790993ec15f5973c7f5fd01866..11ed0e3bab0551d1e9d3e87cc98ef91ee015ac13 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -1,374 +1,271 @@ -use crate::example::{ActualExcerpt, NamedExample}; -use crate::headless::ZetaCliAppState; -use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir}; use crate::{ - CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat, + PredictionProvider, PromptFormat, + anthropic_client::AnthropicClient, + example::{Example, ExamplePrediction}, + format_prompt::{PromptParser, TeacherPrompt, run_format_prompt}, + headless::EpAppState, + load_project::run_load_project, + paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR}, + retrieve_context::run_context_retrieval, +}; +use edit_prediction::{DebugEvent, EditPredictionStore}; +use futures::{FutureExt as _, StreamExt as _, future::Shared}; +use gpui::{AppContext as _, AsyncApp, Task}; +use std::{ + fs, + sync::{ + Arc, Mutex, OnceLock, + atomic::{AtomicUsize, Ordering::SeqCst}, + }, }; -use ::serde::Serialize; -use anyhow::{Context, Result, anyhow}; -use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock}; -use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey}; -use futures::StreamExt as _; -use gpui::{AppContext, AsyncApp, Entity}; -use project::Project; -use project::buffer_store::BufferStoreEvent; -use serde::Deserialize; -use std::fs; -use std::io::{IsTerminal, Write}; -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::Mutex; -use std::time::{Duration, Instant}; -pub async fn run_predict( - args: PredictArguments, - app_state: &Arc, - cx: &mut AsyncApp, +pub async fn run_prediction( + example: &mut Example, + provider: Option, + repetition_count: usize, + app_state: Arc, + mut cx: AsyncApp, ) { - let example = NamedExample::load(args.example_path).unwrap(); - let project = example.setup_project(app_state, cx).await.unwrap(); - let store = setup_store(args.options.provider, &project, app_state, cx).unwrap(); - let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap(); - let result = perform_predict(example, project, store, None, args.options, cx) - .await - .unwrap(); - result.write(args.format, std::io::stdout()).unwrap(); - - print_run_data_dir(true, std::io::stdout().is_terminal()); -} - -pub fn setup_store( - provider: PredictionProvider, - project: &Entity, - app_state: &Arc, - cx: &mut AsyncApp, -) -> Result> { - let store = cx.new(|cx| { - edit_prediction::EditPredictionStore::new( - app_state.client.clone(), - app_state.user_store.clone(), - cx, - ) - })?; + if !example.predictions.is_empty() { + return; + } - store.update(cx, |store, _cx| { - let model = match provider { - PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, - PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, - PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, - }; - store.set_edit_prediction_model(model); - })?; + run_load_project(example, app_state.clone(), cx.clone()).await; + run_context_retrieval(example, app_state.clone(), cx.clone()).await; - let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?; + let provider = provider.unwrap(); - cx.subscribe(&buffer_store, { - let project = project.clone(); - let store = store.clone(); - move |_, event, cx| match event { - BufferStoreEvent::BufferAdded(buffer) => { - store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx)); - } - _ => {} + if matches!(provider, PredictionProvider::Teacher) { + if example.prompt.is_none() { + run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await; } - })? - .detach(); - anyhow::Ok(store) -} - -pub async fn perform_predict( - example: NamedExample, - project: Entity, - store: Entity, - repetition_ix: Option, - options: PredictionOptions, - cx: &mut AsyncApp, -) -> Result { - let mut cache_mode = options.cache; - if repetition_ix.is_some() { - if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip { - panic!("Repetitions are not supported in Auto cache mode"); - } else { - cache_mode = CacheMode::Skip; - } - } else if cache_mode == CacheMode::Auto { - cache_mode = CacheMode::Requests; + let batched = true; + return predict_anthropic(example, repetition_count, batched).await; } - let mut example_run_dir = RUN_DIR.join(&example.file_name()); - if let Some(repetition_ix) = repetition_ix { - example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix)); - } - fs::create_dir_all(&example_run_dir)?; - if LATEST_EXAMPLE_RUN_DIR.is_symlink() { - fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?; + if matches!( + provider, + PredictionProvider::Zeta1 | PredictionProvider::Zeta2 + ) { + static AUTHENTICATED: OnceLock>> = OnceLock::new(); + AUTHENTICATED + .get_or_init(|| { + let client = app_state.client.clone(); + cx.spawn(async move |cx| { + client + .sign_in_with_optional_connect(true, cx) + .await + .unwrap(); + }) + .shared() + }) + .clone() + .await; } - #[cfg(unix)] - std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR) - .context("creating latest link")?; - - #[cfg(windows)] - std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR) - .context("creating latest link")?; - - store.update(cx, |store, _cx| { - store.with_eval_cache(Arc::new(RunCache { - example_run_dir: example_run_dir.clone(), - cache_mode, - })); - })?; - - let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?; - - let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone()))); - - let prompt_format = options.zeta2.prompt_format; - - store.update(cx, |store, _cx| { - let mut options = store.options().clone(); - options.prompt_format = prompt_format.into(); - store.set_options(options); - })?; + let ep_store = cx + .update(|cx| EditPredictionStore::try_global(cx).unwrap()) + .unwrap(); - let mut debug_task = gpui::Task::ready(Ok(())); + ep_store + .update(&mut cx, |store, _cx| { + let model = match provider { + PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1, + PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2, + PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep, + PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury, + PredictionProvider::Teacher => unreachable!(), + }; + store.set_edit_prediction_model(model); + }) + .unwrap(); + let state = example.state.as_ref().unwrap(); + let run_dir = RUN_DIR.join(&example.name); - if options.provider == crate::PredictionProvider::Zeta2 { - let mut debug_rx = store.update(cx, |store, _| store.debug_info())?; + let updated_example = Arc::new(Mutex::new(example.clone())); + let current_run_ix = Arc::new(AtomicUsize::new(0)); - debug_task = cx.background_spawn({ - let result = result.clone(); - async move { - let mut start_time = None; - let mut retrieval_finished_at = None; - while let Some(event) = debug_rx.next().await { - match event { - edit_prediction::DebugEvent::ContextRetrievalStarted(info) => { - start_time = Some(info.timestamp); - fs::write( - example_run_dir.join("search_prompt.md"), - &info.search_prompt, - )?; + let mut debug_rx = ep_store + .update(&mut cx, |store, cx| store.debug_info(&state.project, cx)) + .unwrap(); + let debug_task = cx.background_spawn({ + let updated_example = updated_example.clone(); + let current_run_ix = current_run_ix.clone(); + let run_dir = run_dir.clone(); + async move { + while let Some(event) = debug_rx.next().await { + let run_ix = current_run_ix.load(SeqCst); + let mut updated_example = updated_example.lock().unwrap(); + + let run_dir = if repetition_count > 1 { + run_dir.join(format!("{:03}", run_ix)) + } else { + run_dir.clone() + }; + + match event { + DebugEvent::EditPredictionStarted(request) => { + assert_eq!(updated_example.predictions.len(), run_ix + 1); + + if let Some(prompt) = request.prompt { + fs::write(run_dir.join("prediction_prompt.md"), &prompt)?; } - edit_prediction::DebugEvent::ContextRetrievalFinished(info) => { - retrieval_finished_at = Some(info.timestamp); - for (key, value) in &info.metadata { - if *key == "search_queries" { - fs::write( - example_run_dir.join("search_queries.json"), - value.as_bytes(), - )?; - } - } + } + DebugEvent::EditPredictionFinished(request) => { + assert_eq!(updated_example.predictions.len(), run_ix + 1); + + if let Some(output) = request.model_output { + fs::write(run_dir.join("prediction_response.md"), &output)?; + updated_example + .predictions + .last_mut() + .unwrap() + .actual_output = output; } - edit_prediction::DebugEvent::EditPredictionRequested(request) => { - let prediction_started_at = Instant::now(); - start_time.get_or_insert(prediction_started_at); - let prompt = request.local_prompt.unwrap_or_default(); - fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?; - - { - let mut result = result.lock().unwrap(); - result.prompt_len = prompt.chars().count(); - - for included_file in request.inputs.included_files { - let insertions = - vec![(request.inputs.cursor_point, CURSOR_MARKER)]; - result.excerpts.extend(included_file.excerpts.iter().map( - |excerpt| ActualExcerpt { - path: included_file.path.components().skip(1).collect(), - text: String::from(excerpt.text.as_ref()), - }, - )); - write_codeblock( - &included_file.path, - included_file.excerpts.iter(), - if included_file.path == request.inputs.cursor_path { - &insertions - } else { - &[] - }, - included_file.max_row, - false, - &mut result.excerpts_text, - ); - } - } - - let response = - request.response_rx.await?.0.map_err(|err| anyhow!(err))?; - let response = - edit_prediction::open_ai_response::text_from_response(response) - .unwrap_or_default(); - let prediction_finished_at = Instant::now(); - fs::write(example_run_dir.join("prediction_response.md"), &response)?; - - let mut result = result.lock().unwrap(); - result.generated_len = response.chars().count(); - result.retrieval_time = - retrieval_finished_at.unwrap() - start_time.unwrap(); - result.prediction_time = prediction_finished_at - prediction_started_at; - result.total_time = prediction_finished_at - start_time.unwrap(); - + if run_ix >= repetition_count { break; } } + _ => {} } - anyhow::Ok(()) } - }); - - store.update(cx, |store, cx| { - store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx) - })?; - } - - let prediction = store - .update(cx, |store, cx| { - store.request_prediction( - &project, - &cursor_buffer, - cursor_anchor, - cloud_llm_client::PredictEditsRequestTrigger::Cli, - cx, - ) - })? - .await?; - - debug_task.await?; - - let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap(); - - result.diff = prediction - .and_then(|prediction| { - let prediction = prediction.prediction.ok()?; - prediction.edit_preview.as_unified_diff(&prediction.edits) - }) - .unwrap_or_default(); - - anyhow::Ok(result) -} - -struct RunCache { - cache_mode: CacheMode, - example_run_dir: PathBuf, -} + anyhow::Ok(()) + } + }); -impl RunCache { - fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf { - CACHE_DIR.join(format!("{kind}_out_{key:x}.json",)) - } + for ix in 0..repetition_count { + current_run_ix.store(ix, SeqCst); + let run_dir = if repetition_count > 1 { + run_dir.join(format!("{:03}", ix)) + } else { + run_dir.clone() + }; - fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf { - CACHE_DIR.join(format!("{kind}_in_{key:x}.json",)) + fs::create_dir_all(&run_dir).unwrap(); + if LATEST_EXAMPLE_RUN_DIR.is_symlink() { + fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap(); + } + #[cfg(unix)] + std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap(); + #[cfg(windows)] + std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap(); + + updated_example + .lock() + .unwrap() + .predictions + .push(ExamplePrediction { + actual_patch: String::new(), + actual_output: String::new(), + provider, + }); + + let prediction = ep_store + .update(&mut cx, |store, cx| { + store.request_prediction( + &state.project, + &state.buffer, + state.cursor_position, + cloud_llm_client::PredictEditsRequestTrigger::Cli, + cx, + ) + }) + .unwrap() + .await + .unwrap(); + + updated_example + .lock() + .unwrap() + .predictions + .last_mut() + .unwrap() + .actual_patch = prediction + .and_then(|prediction| { + let prediction = prediction.prediction.ok()?; + prediction.edit_preview.as_unified_diff(&prediction.edits) + }) + .unwrap_or_default(); } - fn link_to_run(&self, key: &EvalCacheKey) { - let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0)); - fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap(); + ep_store + .update(&mut cx, |store, _| { + store.remove_project(&state.project); + }) + .unwrap(); + debug_task.await.unwrap(); - let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0)); - fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap(); - } + *example = Arc::into_inner(updated_example) + .unwrap() + .into_inner() + .unwrap(); } -impl EvalCache for RunCache { - fn read(&self, key: EvalCacheKey) -> Option { - let path = RunCache::output_cache_path(&key); - - if path.exists() { - let use_cache = match key.0 { - EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(), - EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => { - self.cache_mode.use_cached_llm_responses() - } - }; - if use_cache { - log::info!("Using cache entry: {}", path.display()); - self.link_to_run(&key); - Some(fs::read_to_string(path).unwrap()) - } else { - log::trace!("Skipping cached entry: {}", path.display()); - None - } - } else if matches!(self.cache_mode, CacheMode::Force) { - panic!( - "No cached entry found for {:?}. Run without `--cache force` at least once.", - key.0 - ); - } else { - None - } - } - - fn write(&self, key: EvalCacheKey, input: &str, output: &str) { - fs::create_dir_all(&*CACHE_DIR).unwrap(); +async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) { + let llm_model_name = "claude-sonnet-4-5"; + let max_tokens = 16384; + let llm_client = if batched { + AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref()) + } else { + AnthropicClient::plain() + }; + let llm_client = llm_client.expect("Failed to create LLM client"); + + let prompt = example + .prompt + .as_ref() + .unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name)); + + let messages = vec![anthropic::Message { + role: anthropic::Role::User, + content: vec![anthropic::RequestContent::Text { + text: prompt.input.clone(), + cache_control: None, + }], + }]; + + let Some(response) = llm_client + .generate(llm_model_name, max_tokens, messages) + .await + .unwrap() + else { + // Request stashed for batched processing + return; + }; + + let actual_output = response + .content + .into_iter() + .filter_map(|content| match content { + anthropic::ResponseContent::Text { text } => Some(text), + _ => None, + }) + .collect::>() + .join("\n"); - let input_path = RunCache::input_cache_path(&key); - fs::write(&input_path, input).unwrap(); + let actual_patch = TeacherPrompt::parse(example, &actual_output); - let output_path = RunCache::output_cache_path(&key); - log::trace!("Writing cache entry: {}", output_path.display()); - fs::write(&output_path, output).unwrap(); + let prediction = ExamplePrediction { + actual_patch, + actual_output, + provider: PredictionProvider::Teacher, + }; - self.link_to_run(&key); - } + example.predictions.push(prediction); } -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct PredictionDetails { - pub diff: String, - pub excerpts: Vec, - pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly - pub retrieval_time: Duration, - pub prediction_time: Duration, - pub total_time: Duration, - pub run_example_dir: PathBuf, - pub prompt_len: usize, - pub generated_len: usize, -} - -impl PredictionDetails { - pub fn new(run_example_dir: PathBuf) -> Self { - Self { - diff: Default::default(), - excerpts: Default::default(), - excerpts_text: Default::default(), - retrieval_time: Default::default(), - prediction_time: Default::default(), - total_time: Default::default(), - run_example_dir, - prompt_len: 0, - generated_len: 0, +pub async fn sync_batches(provider: &PredictionProvider) { + match provider { + PredictionProvider::Teacher => { + let cache_path = crate::paths::LLM_CACHE_DB.as_ref(); + let llm_client = + AnthropicClient::batch(cache_path).expect("Failed to create LLM client"); + llm_client + .sync_batches() + .await + .expect("Failed to sync batches"); } - } - - pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> { - let formatted = match format { - PredictionsOutputFormat::Md => self.to_markdown(), - PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?, - PredictionsOutputFormat::Diff => self.diff.clone(), - }; - - Ok(out.write_all(formatted.as_bytes())?) - } - - pub fn to_markdown(&self) -> String { - format!( - "## Excerpts\n\n\ - {}\n\n\ - ## Prediction\n\n\ - {}\n\n\ - ## Time\n\n\ - Retrieval: {}ms\n\ - Prediction: {}ms\n\n\ - Total: {}ms\n", - self.excerpts_text, - self.diff, - self.retrieval_time.as_millis(), - self.prediction_time.as_millis(), - self.total_time.as_millis(), - ) + _ => (), } } diff --git a/crates/edit_prediction_cli/src/util.rs b/crates/edit_prediction_cli/src/retrieve_context.rs similarity index 53% rename from crates/edit_prediction_cli/src/util.rs rename to crates/edit_prediction_cli/src/retrieve_context.rs index f4a51d94585f82da008ac832dc62392c365738fd..2344b4250e2dd0d3a94928b05689377dcabba84a 100644 --- a/crates/edit_prediction_cli/src/util.rs +++ b/crates/edit_prediction_cli/src/retrieve_context.rs @@ -1,106 +1,136 @@ -use anyhow::{Result, anyhow}; -use futures::channel::mpsc; -use futures::{FutureExt as _, StreamExt as _}; +use crate::{ + example::{Example, ExampleContext}, + headless::EpAppState, + load_project::run_load_project, +}; +use anyhow::Result; +use collections::HashSet; +use edit_prediction::{DebugEvent, EditPredictionStore}; +use futures::{FutureExt as _, StreamExt as _, channel::mpsc}; use gpui::{AsyncApp, Entity, Task}; -use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus}; -use project::lsp_store::OpenLspBufferHandle; -use project::{Project, ProjectPath, Worktree}; -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Duration; -use util::rel_path::RelPath; - -pub fn open_buffer( - project: Entity, - worktree: Entity, - path: Arc, - cx: &AsyncApp, -) -> Task>> { - cx.spawn(async move |cx| { - let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath { - worktree_id: worktree.id(), - path, - })?; - - let buffer = project - .update(cx, |project, cx| project.open_buffer(project_path, cx))? - .await?; - - let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?; - while *parse_status.borrow() != ParseStatus::Idle { - parse_status.changed().await?; +use language::{Buffer, LanguageNotFound}; +use project::Project; +use std::{sync::Arc, time::Duration}; + +pub async fn run_context_retrieval( + example: &mut Example, + app_state: Arc, + mut cx: AsyncApp, +) { + if example.context.is_some() { + return; + } + + run_load_project(example, app_state.clone(), cx.clone()).await; + + let state = example.state.as_ref().unwrap(); + let project = state.project.clone(); + + let _lsp_handle = project + .update(&mut cx, |project, cx| { + project.register_buffer_with_language_servers(&state.buffer, cx) + }) + .unwrap(); + + wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await; + + let ep_store = cx + .update(|cx| EditPredictionStore::try_global(cx).unwrap()) + .unwrap(); + + let mut events = ep_store + .update(&mut cx, |store, cx| { + store.register_buffer(&state.buffer, &project, cx); + store.set_use_context(true); + store.refresh_context(&project, &state.buffer, state.cursor_position, cx); + store.debug_info(&project, cx) + }) + .unwrap(); + + while let Some(event) = events.next().await { + match event { + DebugEvent::ContextRetrievalFinished(_) => { + break; + } + _ => {} } + } - Ok(buffer) - }) + let context_files = ep_store + .update(&mut cx, |store, cx| store.context_for_project(&project, cx)) + .unwrap(); + + example.context = Some(ExampleContext { + files: context_files, + }); } -pub async fn open_buffer_with_language_server( - project: Entity, - worktree: Entity, - path: Arc, - ready_languages: &mut HashSet, +async fn wait_for_language_server_to_start( + example: &Example, + project: &Entity, + buffer: &Entity, cx: &mut AsyncApp, -) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity)> { - let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?; - - let (lsp_open_handle, path_style) = project.update(cx, |project, cx| { - ( - project.register_buffer_with_language_servers(&buffer, cx), - project.path_style(cx), - ) - })?; - - let language_registry = project.read_with(cx, |project, _| project.languages().clone())?; +) { + let language_registry = project + .read_with(cx, |project, _| project.languages().clone()) + .unwrap(); let result = language_registry - .load_language_for_file_path(path.as_std_path()) + .load_language_for_file_path(&example.cursor_path) .await; if let Err(error) = result && !error.is::() { - anyhow::bail!(error); + panic!("Failed to load language for file path: {}", error); } - let Some(language_id) = buffer.read_with(cx, |buffer, _cx| { - buffer.language().map(|language| language.id()) - })? + let Some(language_id) = buffer + .read_with(cx, |buffer, _cx| { + buffer.language().map(|language| language.id()) + }) + .unwrap() else { - return Err(anyhow!("No language for {}", path.display(path_style))); + panic!("No language for {:?}", example.cursor_path); }; - let log_prefix = format!("{} | ", path.display(path_style)); + let mut ready_languages = HashSet::default(); + let log_prefix = format!("{} | ", example.name); if !ready_languages.contains(&language_id) { - wait_for_lang_server(&project, &buffer, log_prefix, cx).await?; + wait_for_lang_server(&project, &buffer, log_prefix, cx) + .await + .unwrap(); ready_languages.insert(language_id); } - let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?; + let lsp_store = project + .read_with(cx, |project, _cx| project.lsp_store()) + .unwrap(); // hacky wait for buffer to be registered with the language server for _ in 0..100 { - let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| { - buffer.update(cx, |buffer, cx| { - lsp_store - .language_servers_for_local_buffer(&buffer, cx) - .next() - .map(|(_, language_server)| language_server.server_id()) + if lsp_store + .update(cx, |lsp_store, cx| { + buffer.update(cx, |buffer, cx| { + lsp_store + .language_servers_for_local_buffer(&buffer, cx) + .next() + .map(|(_, language_server)| language_server.server_id()) + }) }) - })? - else { + .unwrap() + .is_some() + { + return; + } else { cx.background_executor() .timer(Duration::from_millis(10)) .await; - continue; - }; - - return Ok((lsp_open_handle, language_server_id, buffer)); + } } - return Err(anyhow!("No language server found for buffer")); + panic!("No language server found for buffer"); } -// TODO: Dedupe with similar function in crates/eval/src/instance.rs pub fn wait_for_lang_server( project: &Entity, buffer: &Entity, diff --git a/crates/edit_prediction_cli/src/score.rs b/crates/edit_prediction_cli/src/score.rs new file mode 100644 index 0000000000000000000000000000000000000000..88ec5d5831c763b604c53d762a1ea9722e7279cb --- /dev/null +++ b/crates/edit_prediction_cli/src/score.rs @@ -0,0 +1,119 @@ +use crate::{ + PredictArgs, + example::{Example, ExampleScore}, + headless::EpAppState, + metrics::{self, ClassificationMetrics}, + predict::run_prediction, +}; +use edit_prediction::udiff::DiffLine; +use gpui::AsyncApp; +use std::sync::Arc; + +pub async fn run_scoring( + example: &mut Example, + args: &PredictArgs, + app_state: Arc, + cx: AsyncApp, +) { + run_prediction( + example, + Some(args.provider), + args.repetitions, + app_state, + cx, + ) + .await; + + let expected_patch = parse_patch(&example.expected_patch); + + let mut scores = vec![]; + + for pred in &example.predictions { + let actual_patch = parse_patch(&pred.actual_patch); + let line_match = metrics::line_match_score(&expected_patch, &actual_patch); + let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32; + + scores.push(ExampleScore { + delta_chr_f, + line_match, + }); + } + + example.score = scores; +} + +fn parse_patch(patch: &str) -> Vec> { + patch.lines().map(DiffLine::parse).collect() +} + +pub fn print_report(examples: &[Example]) { + eprintln!( + "──────────────────────────────────────────────────────────────────────────────────────" + ); + eprintln!( + "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}", + "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF" + ); + eprintln!( + "──────────────────────────────────────────────────────────────────────────────────────" + ); + + let mut all_line_match_scores = Vec::new(); + let mut all_delta_chr_f_scores = Vec::new(); + + for example in examples { + for score in example.score.iter() { + let line_match = &score.line_match; + + eprintln!( + "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}", + truncate_name(&example.name, 30), + line_match.true_positives, + line_match.false_positives, + line_match.false_negatives, + line_match.precision() * 100.0, + line_match.recall() * 100.0, + line_match.f1_score() * 100.0, + score.delta_chr_f + ); + + all_line_match_scores.push(line_match.clone()); + all_delta_chr_f_scores.push(score.delta_chr_f); + } + } + + eprintln!( + "──────────────────────────────────────────────────────────────────────────────────────" + ); + + if !all_line_match_scores.is_empty() { + let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter()); + let avg_delta_chr_f: f32 = + all_delta_chr_f_scores.iter().sum::() / all_delta_chr_f_scores.len() as f32; + + eprintln!( + "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}", + "TOTAL", + total_line_match.true_positives, + total_line_match.false_positives, + total_line_match.false_negatives, + total_line_match.precision() * 100.0, + total_line_match.recall() * 100.0, + total_line_match.f1_score() * 100.0, + avg_delta_chr_f + ); + eprintln!( + "──────────────────────────────────────────────────────────────────────────────────────" + ); + } + + eprintln!("\n"); +} + +fn truncate_name(name: &str, max_len: usize) -> String { + if name.len() <= max_len { + name.to_string() + } else { + format!("{}...", &name[..max_len - 3]) + } +} diff --git a/crates/edit_prediction_cli/src/source_location.rs b/crates/edit_prediction_cli/src/source_location.rs deleted file mode 100644 index 3438675e78ac4d8bba6f58f7ce8a9016aed6c0c7..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_cli/src/source_location.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc}; - -use ::util::{paths::PathStyle, rel_path::RelPath}; -use anyhow::{Result, anyhow}; -use language::Point; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; - -#[derive(Debug, Clone, Hash, Eq, PartialEq)] -pub struct SourceLocation { - pub path: Arc, - pub point: Point, -} - -impl Serialize for SourceLocation { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(&self.to_string()) - } -} - -impl<'de> Deserialize<'de> for SourceLocation { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - s.parse().map_err(serde::de::Error::custom) - } -} - -impl Display for SourceLocation { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}:{}:{}", - self.path.display(PathStyle::Posix), - self.point.row + 1, - self.point.column + 1 - ) - } -} - -impl FromStr for SourceLocation { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - let parts: Vec<&str> = s.split(':').collect(); - if parts.len() != 3 { - return Err(anyhow!( - "Invalid source location. Expected 'file.rs:line:column', got '{}'", - s - )); - } - - let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc(); - let line: u32 = parts[1] - .parse() - .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?; - let column: u32 = parts[2] - .parse() - .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?; - - // Convert from 1-based to 0-based indexing - let point = Point::new(line.saturating_sub(1), column.saturating_sub(1)); - - Ok(SourceLocation { path, point }) - } -} diff --git a/crates/edit_prediction_cli/src/training/teacher.prompt.md b/crates/edit_prediction_cli/src/teacher.prompt.md similarity index 98% rename from crates/edit_prediction_cli/src/training/teacher.prompt.md rename to crates/edit_prediction_cli/src/teacher.prompt.md index af67c871ef31a21a8744bf71375a50128d9699b6..238d3b7ac1297583727f562f1755d084ff5a3ceb 100644 --- a/crates/edit_prediction_cli/src/training/teacher.prompt.md +++ b/crates/edit_prediction_cli/src/teacher.prompt.md @@ -46,3 +46,7 @@ Output example: ## Code Context {{context}} + +## Editable region + +{{editable_region}} diff --git a/crates/edit_prediction_cli/src/training/context.rs b/crates/edit_prediction_cli/src/training/context.rs deleted file mode 100644 index 7b6d9cc19c1c3750bbf03158ceec5c79a9df0340..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_cli/src/training/context.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::path::Path; - -use crate::{source_location::SourceLocation, training::teacher::TeacherModel}; - -#[derive(Debug, Clone, Default, clap::ValueEnum)] -pub enum ContextType { - #[default] - CurrentFile, -} - -const MAX_CONTEXT_SIZE: usize = 32768; - -pub fn collect_context( - context_type: &ContextType, - worktree_dir: &Path, - cursor: SourceLocation, -) -> String { - let context = match context_type { - ContextType::CurrentFile => { - let file_path = worktree_dir.join(cursor.path.as_std_path()); - let context = std::fs::read_to_string(&file_path).unwrap_or_default(); - - let context = add_special_tags(&context, worktree_dir, cursor); - context - } - }; - - let region_end_offset = context.find(TeacherModel::REGION_END); - - if context.len() <= MAX_CONTEXT_SIZE { - return context; - } - - if let Some(region_end_offset) = region_end_offset - && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE - { - let to_truncate = context.len() - MAX_CONTEXT_SIZE; - format!( - "[...{} bytes truncated]\n{}\n", - to_truncate, - &context[to_truncate..] - ) - } else { - format!( - "{}\n[...{} bytes truncated]\n", - &context[..MAX_CONTEXT_SIZE], - context.len() - MAX_CONTEXT_SIZE - ) - } -} - -/// Add <|editable_region_start/end|> tags -fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String { - let path = worktree_dir.join(cursor.path.as_std_path()); - let file = std::fs::read_to_string(&path).unwrap_or_default(); - let lines = file.lines().collect::>(); - let cursor_row = cursor.point.row as usize; - let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE); - let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len()); - - let snippet = lines[start_line..end_line].join("\n"); - - if context.contains(&snippet) { - let mut cursor_line = lines[cursor_row].to_string(); - cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR); - - let mut snippet_with_tags_lines = vec![]; - snippet_with_tags_lines.push(TeacherModel::REGION_START); - snippet_with_tags_lines.extend(&lines[start_line..cursor_row]); - snippet_with_tags_lines.push(&cursor_line); - snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]); - snippet_with_tags_lines.push(TeacherModel::REGION_END); - let snippet_with_tags = snippet_with_tags_lines.join("\n"); - - context.replace(&snippet, &snippet_with_tags) - } else { - log::warn!( - "Can't find area around the cursor in the context; proceeding without special tags" - ); - context.to_string() - } -} - -pub fn strip_special_tags(context: &str) -> String { - context - .replace(TeacherModel::REGION_START, "") - .replace(TeacherModel::REGION_END, "") - .replace(TeacherModel::USER_CURSOR, "") -} diff --git a/crates/edit_prediction_cli/src/training/distill.rs b/crates/edit_prediction_cli/src/training/distill.rs deleted file mode 100644 index 277e35551a9fbce43982de832de5ccecf8d6e92e..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_cli/src/training/distill.rs +++ /dev/null @@ -1,94 +0,0 @@ -use serde::Deserialize; -use std::sync::Arc; - -use crate::{ - DistillArguments, - example::Example, - source_location::SourceLocation, - training::{ - context::ContextType, - llm_client::LlmClient, - teacher::{TeacherModel, TeacherOutput}, - }, -}; -use anyhow::Result; -use reqwest_client::ReqwestClient; - -#[derive(Debug, Deserialize)] -pub struct SplitCommit { - repo_url: String, - commit_sha: String, - edit_history: String, - expected_patch: String, - cursor_position: String, -} - -pub async fn run_distill(arguments: DistillArguments) -> Result<()> { - let split_commits: Vec = std::fs::read_to_string(&arguments.split_commit_dataset) - .expect("Failed to read split commit dataset") - .lines() - .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line")) - .collect(); - - let http_client: Arc = Arc::new(ReqwestClient::new()); - - let llm_client = if let Some(cache_path) = arguments.batch { - LlmClient::batch(&cache_path, http_client)? - } else { - LlmClient::plain(http_client)? - }; - - let mut teacher = TeacherModel::new( - "claude-sonnet-4-5".to_string(), - ContextType::CurrentFile, - llm_client, - ); - - let mut num_marked_for_batching = 0; - - for commit in split_commits { - if let Some(distilled) = distill_one(&mut teacher, commit).await? { - println!("{}", serde_json::to_string(&distilled)?); - } else { - if num_marked_for_batching == 0 { - log::warn!("Marked for batching"); - } - num_marked_for_batching += 1; - } - } - - eprintln!( - "{} requests are marked for batching", - num_marked_for_batching - ); - let llm_client = teacher.client; - llm_client.sync_batches().await?; - - Ok(()) -} - -pub async fn distill_one( - teacher: &mut TeacherModel, - commit: SplitCommit, -) -> Result> { - let cursor: SourceLocation = commit - .cursor_position - .parse() - .expect("Failed to parse cursor position"); - - let path = cursor.path.to_rel_path_buf(); - - let example = Example { - repository_url: commit.repo_url, - revision: commit.commit_sha, - uncommitted_diff: commit.edit_history.clone(), - cursor_path: path.as_std_path().to_path_buf(), - cursor_position: commit.cursor_position, - edit_history: commit.edit_history, // todo: trim - expected_patch: commit.expected_patch, - }; - - let prediction = teacher.predict(example).await; - - prediction -} diff --git a/crates/edit_prediction_cli/src/training/mod.rs b/crates/edit_prediction_cli/src/training/mod.rs deleted file mode 100644 index dc564c4dc86c8e095e8e93ccbdfb29d3313e922a..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_cli/src/training/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod context; -pub mod distill; -pub mod llm_client; -pub mod teacher; diff --git a/crates/edit_prediction_cli/src/training/teacher.rs b/crates/edit_prediction_cli/src/training/teacher.rs deleted file mode 100644 index 99672db8f99a87b99a43c8876db2fd0c2f307b21..0000000000000000000000000000000000000000 --- a/crates/edit_prediction_cli/src/training/teacher.rs +++ /dev/null @@ -1,266 +0,0 @@ -use crate::{ - example::Example, - source_location::SourceLocation, - training::{ - context::{ContextType, collect_context, strip_special_tags}, - llm_client::LlmClient, - }, -}; -use anthropic::{Message, RequestContent, ResponseContent, Role}; -use anyhow::Result; - -pub struct TeacherModel { - pub llm_name: String, - pub context: ContextType, - pub client: LlmClient, -} - -#[derive(Debug, serde::Serialize)] -pub struct TeacherOutput { - parsed_output: String, - prompt: String, - raw_llm_response: String, - context: String, - diff: String, -} - -impl TeacherModel { - const PROMPT: &str = include_str!("teacher.prompt.md"); - pub(crate) const REGION_START: &str = "<|editable_region_start|>\n"; - pub(crate) const REGION_END: &str = "<|editable_region_end|>"; - pub(crate) const USER_CURSOR: &str = "<|user_cursor|>"; - - /// Number of lines to include before the cursor position - pub(crate) const LEFT_CONTEXT_SIZE: usize = 5; - - /// Number of lines to include after the cursor position - pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5; - - /// Truncate edit history to this number of last lines - const MAX_HISTORY_LINES: usize = 128; - - pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self { - TeacherModel { - llm_name, - context, - client, - } - } - - pub async fn predict(&self, input: Example) -> Result> { - let name = input.unique_name(); - let worktree_dir = input.setup_worktree(name).await?; - let cursor: SourceLocation = input - .cursor_position - .parse() - .expect("Failed to parse cursor position"); - - let context = collect_context(&self.context, &worktree_dir, cursor.clone()); - let edit_history = Self::format_edit_history(&input.edit_history); - - let prompt = Self::PROMPT - .replace("{{context}}", &context) - .replace("{{edit_history}}", &edit_history); - - let messages = vec![Message { - role: Role::User, - content: vec![RequestContent::Text { - text: prompt.clone(), - cache_control: None, - }], - }]; - - let Some(response) = self - .client - .generate(self.llm_name.clone(), 16384, messages) - .await? - else { - return Ok(None); - }; - - let response_text = response - .content - .into_iter() - .filter_map(|content| match content { - ResponseContent::Text { text } => Some(text), - _ => None, - }) - .collect::>() - .join("\n"); - - let parsed_output = self.parse_response(&response_text); - - let original_editable_region = Self::extract_editable_region(&context); - let context_after_edit = context.replace(&original_editable_region, &parsed_output); - let context_after_edit = strip_special_tags(&context_after_edit); - let context_before_edit = strip_special_tags(&context); - let diff = language::unified_diff(&context_before_edit, &context_after_edit); - - // zeta distill --batch batch_results.txt - // zeta distill - // 1. Run `zeta distill <2000 examples <- all examples>` for the first time - // - store LLM requests in a batch, don't actual send the request - // - send the batch (2000 requests) after all inputs are processed - // 2. `zeta send-batches` - // - upload the batch to Anthropic - - // https://platform.claude.com/docs/en/build-with-claude/batch-processing - // https://crates.io/crates/anthropic-sdk-rust - - // - poll for results - // - when ready, store results in cache (a database) - // 3. `zeta distill` again - // - use the cached results this time - - Ok(Some(TeacherOutput { - parsed_output, - prompt, - raw_llm_response: response_text, - context, - diff, - })) - } - - fn parse_response(&self, content: &str) -> String { - let codeblock = Self::extract_last_codeblock(content); - let editable_region = Self::extract_editable_region(&codeblock); - - editable_region - } - - /// Extract content from the last code-fenced block if any, or else return content as is - fn extract_last_codeblock(text: &str) -> String { - let mut last_block = None; - let mut search_start = 0; - - while let Some(start) = text[search_start..].find("```") { - let start = start + search_start; - let bytes = text.as_bytes(); - let mut backtick_end = start; - - while backtick_end < bytes.len() && bytes[backtick_end] == b'`' { - backtick_end += 1; - } - - let backtick_count = backtick_end - start; - let closing_backticks = "`".repeat(backtick_count); - - if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) { - let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1]; - last_block = Some(code_block.to_string()); - search_start = backtick_end + end_pos + backtick_count; - } else { - break; - } - } - - last_block.unwrap_or_else(|| text.to_string()) - } - - fn extract_editable_region(text: &str) -> String { - let start = text - .find(Self::REGION_START) - .map_or(0, |pos| pos + Self::REGION_START.len()); - let end = text.find(Self::REGION_END).unwrap_or(text.len()); - - text[start..end].to_string() - } - - /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines) - fn format_edit_history(edit_history: &str) -> String { - let lines = edit_history - .lines() - .filter(|&s| Self::is_content_line(s)) - .collect::>(); - - let history_lines = if lines.len() > Self::MAX_HISTORY_LINES { - &lines[lines.len() - Self::MAX_HISTORY_LINES..] - } else { - &lines - }; - history_lines.join("\n") - } - - fn is_content_line(s: &str) -> bool { - s.starts_with("-") - || s.starts_with("+") - || s.starts_with(" ") - || s.starts_with("---") - || s.starts_with("+++") - || s.starts_with("@@") - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_response() { - let teacher = TeacherModel::new( - "test".to_string(), - ContextType::CurrentFile, - LlmClient::dummy(), - ); - let response = "This is a test response."; - let parsed = teacher.parse_response(response); - assert_eq!(parsed, response.to_string()); - - let response = indoc::indoc! {" - Some thinking - - ````` - actual response - ````` - "}; - let parsed = teacher.parse_response(response); - assert_eq!(parsed, "actual response"); - } - - #[test] - fn test_extract_last_code_block() { - let text = indoc::indoc! {" - Some thinking - - ``` - first block - ``` - - ````` - last block - ````` - "}; - let last_block = TeacherModel::extract_last_codeblock(text); - assert_eq!(last_block, "last block"); - } - - #[test] - fn test_extract_editable_region() { - let teacher = TeacherModel::new( - "test".to_string(), - ContextType::CurrentFile, - LlmClient::dummy(), - ); - let response = indoc::indoc! {" - some lines - are - here - <|editable_region_start|> - one - two three - - <|editable_region_end|> - more - lines here - "}; - let parsed = teacher.parse_response(response); - assert_eq!( - parsed, - indoc::indoc! {" - one - two three - - "} - ); - } -} diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index f113c3c46075ca70e61d8d07947d37502e8528e8..731ffc85d159e285ad497c29fba2f74179d4149b 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -26,6 +26,7 @@ serde.workspace = true smallvec.workspace = true tree-sitter.workspace = true util.workspace = true +zeta_prompt.workspace = true [dev-dependencies] env_logger.workspace = true diff --git a/crates/edit_prediction_context/src/assemble_excerpts.rs b/crates/edit_prediction_context/src/assemble_excerpts.rs index 15f4c03d653429af671c22d6b5abc652d282a38e..e337211cf90f0e4fbcb481f836e512b1ceb6477f 100644 --- a/crates/edit_prediction_context/src/assemble_excerpts.rs +++ b/crates/edit_prediction_context/src/assemble_excerpts.rs @@ -1,6 +1,6 @@ -use crate::RelatedExcerpt; use language::{BufferSnapshot, OffsetRangeExt as _, Point}; use std::ops::Range; +use zeta_prompt::RelatedExcerpt; #[cfg(not(test))] const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512; @@ -76,14 +76,9 @@ pub fn assemble_excerpts( input_ranges .into_iter() - .map(|range| { - let offset_range = range.to_offset(buffer); - RelatedExcerpt { - point_range: range, - anchor_range: buffer.anchor_before(offset_range.start) - ..buffer.anchor_after(offset_range.end), - text: buffer.as_rope().slice(offset_range), - } + .map(|range| RelatedExcerpt { + row_range: range.start.row..range.end.row, + text: buffer.text_for_range(range).collect(), }) .collect() } diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index d3aefaa6e4ec585dc7c90fee1e95de17e018f90f..15576a835d9b4b0781b1e3979edbed443fa40f62 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -3,13 +3,13 @@ use anyhow::Result; use collections::HashMap; use futures::{FutureExt, StreamExt as _, channel::mpsc, future}; use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity}; -use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _}; +use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _}; use project::{LocationLink, Project, ProjectPath}; -use serde::{Serialize, Serializer}; use smallvec::SmallVec; use std::{ collections::hash_map, ops::Range, + path::Path, sync::Arc, time::{Duration, Instant}, }; @@ -24,12 +24,14 @@ mod fake_definition_lsp; pub use cloud_llm_client::predict_edits_v3::Line; pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; +pub use zeta_prompt::{RelatedExcerpt, RelatedFile}; const IDENTIFIER_LINE_COUNT: u32 = 3; pub struct RelatedExcerptStore { project: WeakEntity, - related_files: Vec, + related_files: Arc<[RelatedFile]>, + related_file_buffers: Vec>, cache: HashMap>, update_tx: mpsc::UnboundedSender<(Entity, Anchor)>, identifier_line_count: u32, @@ -68,82 +70,6 @@ struct CachedDefinition { anchor_range: Range, } -#[derive(Clone, Debug, Serialize)] -pub struct RelatedFile { - #[serde(serialize_with = "serialize_project_path")] - pub path: ProjectPath, - #[serde(skip)] - pub buffer: WeakEntity, - pub excerpts: Vec, - pub max_row: u32, -} - -impl RelatedFile { - pub fn merge_excerpts(&mut self) { - self.excerpts.sort_unstable_by(|a, b| { - a.point_range - .start - .cmp(&b.point_range.start) - .then(b.point_range.end.cmp(&a.point_range.end)) - }); - - let mut index = 1; - while index < self.excerpts.len() { - if self.excerpts[index - 1] - .point_range - .end - .cmp(&self.excerpts[index].point_range.start) - .is_ge() - { - let removed = self.excerpts.remove(index); - if removed - .point_range - .end - .cmp(&self.excerpts[index - 1].point_range.end) - .is_gt() - { - self.excerpts[index - 1].point_range.end = removed.point_range.end; - self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end; - } - } else { - index += 1; - } - } - } -} - -#[derive(Clone, Debug, Serialize)] -pub struct RelatedExcerpt { - #[serde(skip)] - pub anchor_range: Range, - #[serde(serialize_with = "serialize_point_range")] - pub point_range: Range, - #[serde(serialize_with = "serialize_rope")] - pub text: Rope, -} - -fn serialize_project_path( - project_path: &ProjectPath, - serializer: S, -) -> Result { - project_path.path.serialize(serializer) -} - -fn serialize_rope(rope: &Rope, serializer: S) -> Result { - rope.to_string().serialize(serializer) -} - -fn serialize_point_range( - range: &Range, - serializer: S, -) -> Result { - [ - [range.start.row, range.start.column], - [range.end.row, range.end.column], - ] - .serialize(serializer) -} - const DEBOUNCE_DURATION: Duration = Duration::from_millis(100); impl EventEmitter for RelatedExcerptStore {} @@ -179,7 +105,8 @@ impl RelatedExcerptStore { RelatedExcerptStore { project: project.downgrade(), update_tx, - related_files: Vec::new(), + related_files: Vec::new().into(), + related_file_buffers: Vec::new(), cache: Default::default(), identifier_line_count: IDENTIFIER_LINE_COUNT, } @@ -193,8 +120,21 @@ impl RelatedExcerptStore { self.update_tx.unbounded_send((buffer, position)).ok(); } - pub fn related_files(&self) -> &[RelatedFile] { - &self.related_files + pub fn related_files(&self) -> Arc<[RelatedFile]> { + self.related_files.clone() + } + + pub fn related_files_with_buffers( + &self, + ) -> impl Iterator)> { + self.related_files + .iter() + .cloned() + .zip(self.related_file_buffers.iter().cloned()) + } + + pub fn set_related_files(&mut self, files: Vec) { + self.related_files = files.into(); } async fn fetch_excerpts( @@ -297,7 +237,8 @@ impl RelatedExcerptStore { } mean_definition_latency /= cache_miss_count.max(1) as u32; - let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?; + let (new_cache, related_files, related_file_buffers) = + rebuild_related_files(&project, new_cache, cx).await?; if let Some(file) = &file { log::debug!( @@ -309,7 +250,8 @@ impl RelatedExcerptStore { this.update(cx, |this, cx| { this.cache = new_cache; - this.related_files = related_files; + this.related_files = related_files.into(); + this.related_file_buffers = related_file_buffers; cx.emit(RelatedExcerptStoreEvent::FinishedRefresh { cache_hit_count, cache_miss_count, @@ -323,10 +265,16 @@ impl RelatedExcerptStore { } async fn rebuild_related_files( + project: &Entity, new_entries: HashMap>, cx: &mut AsyncApp, -) -> Result<(HashMap>, Vec)> { +) -> Result<( + HashMap>, + Vec, + Vec>, +)> { let mut snapshots = HashMap::default(); + let mut worktree_root_names = HashMap::default(); for entry in new_entries.values() { for definition in &entry.definitions { if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) { @@ -340,12 +288,22 @@ async fn rebuild_related_files( .read_with(cx, |buffer, _| buffer.snapshot())?, ); } + let worktree_id = definition.path.worktree_id; + if let hash_map::Entry::Vacant(e) = + worktree_root_names.entry(definition.path.worktree_id) + { + project.read_with(cx, |project, cx| { + if let Some(worktree) = project.worktree_for_id(worktree_id, cx) { + e.insert(worktree.read(cx).root_name().as_unix_str().to_string()); + } + })?; + } } } Ok(cx .background_spawn(async move { - let mut files = Vec::::new(); + let mut files = Vec::new(); let mut ranges_by_buffer = HashMap::<_, Vec>>::default(); let mut paths_by_buffer = HashMap::default(); for entry in new_entries.values() { @@ -369,16 +327,31 @@ async fn rebuild_related_files( continue; }; let excerpts = assemble_excerpts(snapshot, ranges); - files.push(RelatedFile { - path: project_path.clone(), - buffer: buffer.downgrade(), - excerpts, - max_row: snapshot.max_point().row, - }); + let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else { + continue; + }; + + let path = Path::new(&format!( + "{}/{}", + root_name, + project_path.path.as_unix_str() + )) + .into(); + + files.push(( + buffer, + RelatedFile { + path, + excerpts, + max_row: snapshot.max_point().row, + }, + )); } - files.sort_by_key(|file| file.path.clone()); - (new_entries, files) + files.sort_by_key(|(_, file)| file.path.clone()); + let (related_buffers, related_files) = files.into_iter().unzip(); + + (new_entries, related_files, related_buffers) }) .await) } diff --git a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs index dba8d89e593ccb60e7eae5d091708e82debef0f5..d93a66081164a3fc70f7e1072d91a02bd9adbd37 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context_tests.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context_tests.rs @@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) { &excerpts, &[ ( - "src/company.rs", + "root/src/company.rs", &[indoc! {" pub struct Company { owner: Arc, @@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) { }"}], ), ( - "src/main.rs", + "root/src/main.rs", &[ indoc! {" pub struct Session { @@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) { ], ), ( - "src/person.rs", + "root/src/person.rs", &[ indoc! {" impl Person { @@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, & .iter() .map(|excerpt| excerpt.text.to_string()) .collect::>(); - (file.path.path.as_unix_str(), excerpts) + (file.path.to_str().unwrap(), excerpts) }) .collect::>(); let expected_excerpts = expected_files @@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String { if excerpt.text.is_empty() { continue; } - if current_row < excerpt.point_range.start.row { + if current_row < excerpt.row_range.start { writeln!(&mut output, "…").unwrap(); } - current_row = excerpt.point_range.start.row; + current_row = excerpt.row_range.start; for line in excerpt.text.to_string().lines() { output.push_str(line); diff --git a/crates/edit_prediction_ui/Cargo.toml b/crates/edit_prediction_ui/Cargo.toml index fb846f35d76ae2f6478ef675f246e4d06fe5f469..d6fc45512132197a3b9e7bd200c3005efa52ae10 100644 --- a/crates/edit_prediction_ui/Cargo.toml +++ b/crates/edit_prediction_ui/Cargo.toml @@ -17,7 +17,6 @@ anyhow.workspace = true buffer_diff.workspace = true client.workspace = true cloud_llm_client.workspace = true -cloud_zeta2_prompt.workspace = true codestral.workspace = true command_palette_hooks.workspace = true copilot.workspace = true @@ -46,6 +45,7 @@ ui_input.workspace = true util.workspace = true workspace.workspace = true zed_actions.workspace = true +zeta_prompt.workspace = true [dev-dependencies] copilot = { workspace = true, features = ["test-support"] } diff --git a/crates/edit_prediction_ui/src/edit_prediction_context_view.rs b/crates/edit_prediction_ui/src/edit_prediction_context_view.rs index 0e343fe3fcb8ed7bb6bf3e8481927344d63133ee..92d66d2bec3a7a3b35678f1d4da92fae6b071633 100644 --- a/crates/edit_prediction_ui/src/edit_prediction_context_view.rs +++ b/crates/edit_prediction_ui/src/edit_prediction_context_view.rs @@ -17,7 +17,7 @@ use gpui::{ }; use multi_buffer::MultiBuffer; use project::Project; -use text::OffsetRangeExt; +use text::Point; use ui::{ ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName, StyledTypography as _, h_flex, v_flex, @@ -66,7 +66,7 @@ impl EditPredictionContextView { ) -> Self { let store = EditPredictionStore::global(client, user_store, cx); - let mut debug_rx = store.update(cx, |store, _| store.debug_info()); + let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx)); let _update_task = cx.spawn_in(window, async move |this, cx| { while let Some(event) = debug_rx.next().await { this.update_in(cx, |this, window, cx| { @@ -103,7 +103,8 @@ impl EditPredictionContextView { self.handle_context_retrieval_finished(info, window, cx); } } - DebugEvent::EditPredictionRequested(_) => {} + DebugEvent::EditPredictionStarted(_) => {} + DebugEvent::EditPredictionFinished(_) => {} } } @@ -152,12 +153,11 @@ impl EditPredictionContextView { run.finished_at = Some(info.timestamp); run.metadata = info.metadata; - let project = self.project.clone(); let related_files = self .store .read(cx) - .context_for_project(&self.project, cx) - .to_vec(); + .context_for_project_with_buffers(&self.project, cx) + .map_or(Vec::new(), |files| files.collect()); let editor = run.editor.clone(); let multibuffer = run.editor.read(cx).buffer().clone(); @@ -168,33 +168,14 @@ impl EditPredictionContextView { cx.spawn_in(window, async move |this, cx| { let mut paths = Vec::new(); - for related_file in related_files { - let (buffer, point_ranges): (_, Vec<_>) = - if let Some(buffer) = related_file.buffer.upgrade() { - let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?; - - ( - buffer, - related_file - .excerpts - .iter() - .map(|excerpt| excerpt.anchor_range.to_point(&snapshot)) - .collect(), - ) - } else { - ( - project - .update(cx, |project, cx| { - project.open_buffer(related_file.path.clone(), cx) - })? - .await?, - related_file - .excerpts - .iter() - .map(|excerpt| excerpt.point_range.clone()) - .collect(), - ) - }; + for (related_file, buffer) in related_files { + let point_ranges = related_file + .excerpts + .iter() + .map(|excerpt| { + Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0) + }) + .collect::>(); cx.update(|_, cx| { let path = PathKey::for_buffer(&buffer, cx); paths.push((path, buffer, point_ranges)); diff --git a/crates/edit_prediction_ui/src/rate_prediction_modal.rs b/crates/edit_prediction_ui/src/rate_prediction_modal.rs index 8e754b33dc18c5be60bc052c33aa08cdcb980acb..54933fbf904f8fc7146dcce9a6bd3340884cc8bf 100644 --- a/crates/edit_prediction_ui/src/rate_prediction_modal.rs +++ b/crates/edit_prediction_ui/src/rate_prediction_modal.rs @@ -1,5 +1,4 @@ use buffer_diff::{BufferDiff, BufferDiffSnapshot}; -use cloud_zeta2_prompt::write_codeblock; use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore}; use editor::{Editor, ExcerptRange, MultiBuffer}; use feature_flags::FeatureFlag; @@ -362,14 +361,14 @@ impl RatePredictionsModal { write!(&mut formatted_inputs, "## Events\n\n").unwrap(); for event in &prediction.inputs.events { - write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap(); + formatted_inputs.push_str("```diff\n"); + zeta_prompt::write_event(&mut formatted_inputs, event.as_ref()); + formatted_inputs.push_str("```\n\n"); } - write!(&mut formatted_inputs, "## Included files\n\n").unwrap(); - - for included_file in &prediction.inputs.included_files { - let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")]; + write!(&mut formatted_inputs, "## Related files\n\n").unwrap(); + for included_file in prediction.inputs.related_files.as_ref() { write!( &mut formatted_inputs, "### {}\n\n", @@ -377,20 +376,28 @@ impl RatePredictionsModal { ) .unwrap(); - write_codeblock( - &included_file.path, - &included_file.excerpts, - if included_file.path == prediction.inputs.cursor_path { - cursor_insertions.as_slice() - } else { - &[] - }, - included_file.max_row, - false, - &mut formatted_inputs, - ); + for excerpt in included_file.excerpts.iter() { + write!( + &mut formatted_inputs, + "```{}\n{}\n```\n", + included_file.path.display(), + excerpt.text + ) + .unwrap(); + } } + write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap(); + + writeln!( + &mut formatted_inputs, + "```{}\n{}{}\n```\n", + prediction.inputs.cursor_path.display(), + &prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt], + &prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..], + ) + .unwrap(); + self.active_prediction = Some(ActivePrediction { prediction, feedback_editor: cx.new(|cx| { diff --git a/crates/zeta_prompt/Cargo.toml b/crates/zeta_prompt/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..c9b1e2d784d10ea2fd278f70ffdae2ef0981fce0 --- /dev/null +++ b/crates/zeta_prompt/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "zeta_prompt" +version = "0.1.0" +publish.workspace = true +edition.workspace = true +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/zeta_prompt.rs" + +[dependencies] +serde.workspace = true \ No newline at end of file diff --git a/crates/cloud_zeta2_prompt/LICENSE-GPL b/crates/zeta_prompt/LICENSE-GPL similarity index 100% rename from crates/cloud_zeta2_prompt/LICENSE-GPL rename to crates/zeta_prompt/LICENSE-GPL diff --git a/crates/zeta_prompt/src/zeta_prompt.rs b/crates/zeta_prompt/src/zeta_prompt.rs new file mode 100644 index 0000000000000000000000000000000000000000..21fbca1ae10b715d0c11a31dc9390aada03fa157 --- /dev/null +++ b/crates/zeta_prompt/src/zeta_prompt.rs @@ -0,0 +1,165 @@ +use serde::{Deserialize, Serialize}; +use std::fmt::Write; +use std::ops::Range; +use std::path::Path; +use std::sync::Arc; + +pub const CURSOR_MARKER: &str = "<|user_cursor|>"; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ZetaPromptInput { + pub cursor_path: Arc, + pub cursor_excerpt: Arc, + pub editable_range_in_excerpt: Range, + pub cursor_offset_in_excerpt: usize, + pub events: Vec>, + pub related_files: Arc<[RelatedFile]>, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "event")] +pub enum Event { + BufferChange { + path: Arc, + old_path: Arc, + diff: String, + predicted: bool, + in_open_source_repo: bool, + }, +} + +pub fn write_event(prompt: &mut String, event: &Event) { + fn write_path_as_unix_str(prompt: &mut String, path: &Path) { + for component in path.components() { + prompt.push('/'); + write!(prompt, "{}", component.as_os_str().display()).ok(); + } + } + match event { + Event::BufferChange { + path, + old_path, + diff, + predicted, + in_open_source_repo: _, + } => { + if *predicted { + prompt.push_str("// User accepted prediction:\n"); + } + prompt.push_str("--- a"); + write_path_as_unix_str(prompt, old_path.as_ref()); + prompt.push_str("\n+++ b"); + write_path_as_unix_str(prompt, path.as_ref()); + prompt.push('\n'); + prompt.push_str(diff); + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RelatedFile { + pub path: Arc, + pub max_row: u32, + pub excerpts: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RelatedExcerpt { + pub row_range: Range, + pub text: String, +} + +pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String { + let mut prompt = String::new(); + write_related_files(&mut prompt, &input.related_files); + write_edit_history_section(&mut prompt, input); + write_cursor_excerpt_section(&mut prompt, input); + prompt +} + +pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) { + push_delimited(prompt, "related_files", &[], |prompt| { + for file in related_files { + let path_str = file.path.to_string_lossy(); + push_delimited(prompt, "related_file", &[("path", &path_str)], |prompt| { + for excerpt in &file.excerpts { + push_delimited( + prompt, + "related_excerpt", + &[( + "lines", + &format!( + "{}-{}", + excerpt.row_range.start + 1, + excerpt.row_range.end + 1 + ), + )], + |prompt| { + prompt.push_str(&excerpt.text); + prompt.push('\n'); + }, + ); + } + }); + } + }); +} + +fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) { + push_delimited(prompt, "edit_history", &[], |prompt| { + if input.events.is_empty() { + prompt.push_str("(No edit history)"); + } else { + for event in &input.events { + write_event(prompt, event); + } + } + }); +} + +fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) { + push_delimited(prompt, "cursor_excerpt", &[], |prompt| { + let path_str = input.cursor_path.to_string_lossy(); + push_delimited(prompt, "file", &[("path", &path_str)], |prompt| { + prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]); + push_delimited(prompt, "editable_region", &[], |prompt| { + prompt.push_str( + &input.cursor_excerpt + [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt], + ); + prompt.push_str(CURSOR_MARKER); + prompt.push_str( + &input.cursor_excerpt + [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end], + ); + }); + prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]); + }); + }); +} + +fn push_delimited( + prompt: &mut String, + tag: &'static str, + arguments: &[(&str, &str)], + cb: impl FnOnce(&mut String), +) { + if !prompt.ends_with("\n") { + prompt.push('\n'); + } + prompt.push('<'); + prompt.push_str(tag); + for (arg_name, arg_value) in arguments { + write!(prompt, " {}=\"{}\"", arg_name, arg_value).ok(); + } + prompt.push_str(">\n"); + + cb(prompt); + + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + prompt.push_str("\n"); +}