Detailed changes
@@ -3111,16 +3111,6 @@ dependencies = [
"uuid",
]
-[[package]]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "cloud_llm_client",
- "indoc",
- "serde",
-]
-
[[package]]
name = "cmake"
version = "0.1.54"
@@ -5119,7 +5109,6 @@ dependencies = [
"clock",
"cloud_api_types",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"collections",
"copilot",
"credentials_provider",
@@ -5150,8 +5139,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
- "smol",
- "strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
@@ -5162,6 +5149,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
+ "zeta_prompt",
"zlog",
]
@@ -5175,11 +5163,10 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"collections",
"debug_adapter_extension",
+ "dirs 4.0.0",
"edit_prediction",
- "edit_prediction_context",
"extension",
"fs",
"futures 0.3.31",
@@ -5209,9 +5196,10 @@ dependencies = [
"sqlez",
"sqlez_macros",
"terminal_view",
- "toml 0.8.23",
"util",
+ "wasmtime",
"watch",
+ "zeta_prompt",
"zlog",
]
@@ -5239,6 +5227,7 @@ dependencies = [
"text",
"tree-sitter",
"util",
+ "zeta_prompt",
"zlog",
]
@@ -5260,7 +5249,6 @@ dependencies = [
"buffer_diff",
"client",
"cloud_llm_client",
- "cloud_zeta2_prompt",
"codestral",
"command_palette_hooks",
"copilot",
@@ -5291,6 +5279,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
+ "zeta_prompt",
]
[[package]]
@@ -20933,6 +20922,13 @@ dependencies = [
"syn 2.0.106",
]
+[[package]]
+name = "zeta_prompt"
+version = "0.1.0"
+dependencies = [
+ "serde",
+]
+
[[package]]
name = "zip"
version = "0.6.6"
@@ -32,7 +32,6 @@ members = [
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
- "crates/cloud_zeta2_prompt",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -202,6 +201,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/edit_prediction_cli",
+ "crates/zeta_prompt",
"crates/zlog",
"crates/zlog_settings",
"crates/ztracing",
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
-cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
+zeta_prompt = { path = "crates/zeta_prompt" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
ztracing = { path = "crates/ztracing" }
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
tiny_http = "0.8"
tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
+tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
toml = "0.8"
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
tower-http = "0.4.4"
@@ -53,7 +53,7 @@ text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
-tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
+tokio-socks.workspace = true
tokio.workspace = true
url.workspace = true
util.workspace = true
@@ -1,18 +0,0 @@
-[package]
-name = "cloud_zeta2_prompt"
-version = "0.1.0"
-publish.workspace = true
-edition.workspace = true
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/cloud_zeta2_prompt.rs"
-
-[dependencies]
-anyhow.workspace = true
-cloud_llm_client.workspace = true
-indoc.workspace = true
-serde.workspace = true
@@ -1,485 +0,0 @@
-use anyhow::Result;
-use cloud_llm_client::predict_edits_v3::{
- self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
-};
-use indoc::indoc;
-use std::cmp;
-use std::fmt::Write;
-use std::path::Path;
-use std::sync::Arc;
-
-pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
-
-pub const CURSOR_MARKER: &str = "<|user_cursor|>";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
-/// NOTE: Differs from zed version of constant - includes a newline
-pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
-
-const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
- You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
-
- ## Edit History
-
- "#};
-
-const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
- ---
-
- Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
- Do not include the cursor marker in your output.
- If you're editing multiple files, be sure to reflect filename in the hunk's header.
- "};
-
-const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
- # Instructions
-
- You are an edit prediction agent in a code editor.
-
- Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
- Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
- Always continue along the user's current trajectory, rather than changing course.
-
- ## Output Format
-
- You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
- along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
-
- <edits path="my-project/src/myapp/cli.py">
- <old_text>
- OLD TEXT 1 HERE
- </old_text>
- <new_text>
- NEW TEXT 1 HERE
- </new_text>
-
- <old_text>
- OLD TEXT 1 HERE
- </old_text>
- <new_text>
- NEW TEXT 1 HERE
- </new_text>
- </edits>
-
- - Specify the file to edit using the `path` attribute.
- - Use `<old_text>` and `<new_text>` tags to replace content
- - `<old_text>` must exactly match existing file content, including indentation
- - `<old_text>` cannot be empty
- - Do not escape quotes, newlines, or other characters within tags
- - Always close all tags properly
- - Don't include the <|user_cursor|> marker in your output.
-
- ## Edit History
-
-"#};
-
-const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
- ---
-
- Remember that the edits in the edit history have already been applied.
-"#};
-
-pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
- let prompt_data = PromptData {
- events: request.events.clone(),
- cursor_point: request.cursor_point,
- cursor_path: request.excerpt_path.clone(),
- included_files: request.related_files.clone(),
- };
- match request.prompt_format {
- PromptFormat::MinimalQwen => {
- return Ok(MinimalQwenPrompt.render(&prompt_data));
- }
- PromptFormat::SeedCoder1120 => {
- return Ok(SeedCoder1120Prompt.render(&prompt_data));
- }
- _ => (),
- };
-
- let insertions = match request.prompt_format {
- PromptFormat::Minimal | PromptFormat::OldTextNewText => {
- vec![(request.cursor_point, CURSOR_MARKER)]
- }
- PromptFormat::OnlySnippets => vec![],
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- };
-
- let mut prompt = match request.prompt_format {
- PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
- PromptFormat::OnlySnippets => String::new(),
- PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
- PromptFormat::MinimalQwen => unreachable!(),
- PromptFormat::SeedCoder1120 => unreachable!(),
- };
-
- if request.events.is_empty() {
- prompt.push_str("(No edit history)\n\n");
- } else {
- let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
- "The following are the latest edits made by the user, from earlier to later.\n\n"
- } else {
- "Here are the latest edits made by the user, from earlier to later.\n\n"
- };
- prompt.push_str(edit_preamble);
- push_events(&mut prompt, &request.events);
- }
-
- let excerpts_preamble = match request.prompt_format {
- PromptFormat::Minimal => indoc! {"
- ## Part of the file under the cursor
-
- (The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history has been applied.
- We only show part of the file around the cursor.
- You can only edit exactly this part of the file.
- We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
- "},
- PromptFormat::OldTextNewText => indoc! {"
- ## Code Excerpts
-
- Here is some excerpts of code that you should take into account to predict the next edit.
-
- The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
-
- In addition other excerpts are included to better understand what the edit will be, including the declaration
- or references of symbols around the cursor, or other similar code snippets that may need to be updated
- following patterns that appear in the edit history.
-
- Consider each of them carefully in relation to the edit history, and that the user may not have navigated
- to the next place they want to edit yet.
-
- Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
- "},
- PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
- indoc! {"
- ## Code Excerpts
-
- The cursor marker <|user_cursor|> indicates the current user cursor position.
- The file is in current state, edits from edit history have been applied.
- "}
- }
- };
-
- prompt.push_str(excerpts_preamble);
- prompt.push('\n');
-
- let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
- for related_file in &request.related_files {
- if request.prompt_format == PromptFormat::Minimal {
- write_codeblock_with_filename(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- } else {
- write_codeblock(
- &related_file.path,
- &related_file.excerpts,
- if related_file.path == request.excerpt_path {
- &insertions
- } else {
- &[]
- },
- related_file.max_row,
- include_line_numbers,
- &mut prompt,
- );
- }
- }
-
- match request.prompt_format {
- PromptFormat::OldTextNewText => {
- prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
- }
- PromptFormat::Minimal => {
- prompt.push_str(MINIMAL_PROMPT_REMINDER);
- }
- _ => {}
- }
-
- Ok(prompt)
-}
-
-pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
- match prompt_format {
- PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
- _ => GenerationParams::default(),
- }
-}
-
-pub fn write_codeblock<'a>(
- path: &Path,
- excerpts: impl IntoIterator<Item = &'a Excerpt>,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &'a mut String,
-) {
- writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
-
- write_excerpts(
- excerpts,
- sorted_insertions,
- file_line_count,
- include_line_numbers,
- output,
- );
- write!(output, "`````\n\n").unwrap();
-}
-
-fn write_codeblock_with_filename<'a>(
- path: &Path,
- excerpts: impl IntoIterator<Item = &'a Excerpt>,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &'a mut String,
-) {
- writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
-
- write_excerpts(
- excerpts,
- sorted_insertions,
- file_line_count,
- include_line_numbers,
- output,
- );
- write!(output, "`````\n\n").unwrap();
-}
-
-pub fn write_excerpts<'a>(
- excerpts: impl IntoIterator<Item = &'a Excerpt>,
- sorted_insertions: &[(Point, &str)],
- file_line_count: Line,
- include_line_numbers: bool,
- output: &mut String,
-) {
- let mut current_row = Line(0);
- let mut sorted_insertions = sorted_insertions.iter().peekable();
-
- for excerpt in excerpts {
- if excerpt.start_line > current_row {
- writeln!(output, "…").unwrap();
- }
- if excerpt.text.is_empty() {
- return;
- }
-
- current_row = excerpt.start_line;
-
- for mut line in excerpt.text.lines() {
- if include_line_numbers {
- write!(output, "{}|", current_row.0 + 1).unwrap();
- }
-
- while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
- match current_row.cmp(&insertion_location.line) {
- cmp::Ordering::Equal => {
- let (prefix, suffix) = line.split_at(insertion_location.column as usize);
- output.push_str(prefix);
- output.push_str(insertion_marker);
- line = suffix;
- sorted_insertions.next();
- }
- cmp::Ordering::Less => break,
- cmp::Ordering::Greater => {
- sorted_insertions.next();
- break;
- }
- }
- }
- output.push_str(line);
- output.push('\n');
- current_row.0 += 1;
- }
- }
-
- if current_row < file_line_count {
- writeln!(output, "…").unwrap();
- }
-}
-
-pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
- if events.is_empty() {
- return;
- };
-
- writeln!(output, "`````diff").unwrap();
- for event in events {
- writeln!(output, "{}", event).unwrap();
- }
- writeln!(output, "`````\n").unwrap();
-}
-
-struct PromptData {
- events: Vec<Arc<Event>>,
- cursor_point: Point,
- cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
- included_files: Vec<RelatedFile>,
-}
-
-#[derive(Default)]
-pub struct GenerationParams {
- pub temperature: Option<f32>,
- pub top_p: Option<f32>,
- pub stop: Option<Vec<String>>,
-}
-
-trait PromptFormatter {
- fn render(&self, data: &PromptData) -> String;
-
- fn generation_params() -> GenerationParams {
- return GenerationParams::default();
- }
-}
-
-struct MinimalQwenPrompt;
-
-impl PromptFormatter for MinimalQwenPrompt {
- fn render(&self, data: &PromptData) -> String {
- let edit_history = self.fmt_edit_history(data);
- let context = self.fmt_context(data);
-
- format!(
- "{instructions}\n\n{edit_history}\n\n{context}",
- instructions = MinimalQwenPrompt::INSTRUCTIONS,
- edit_history = edit_history,
- context = context
- )
- }
-}
-
-impl MinimalQwenPrompt {
- const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
-
- fn fmt_edit_history(&self, data: &PromptData) -> String {
- if data.events.is_empty() {
- "(No edit history)\n\n".to_string()
- } else {
- let mut events_str = String::new();
- push_events(&mut events_str, &data.events);
- format!(
- "The following are the latest edits made by the user, from earlier to later.\n\n{}",
- events_str
- )
- }
- }
-
- fn fmt_context(&self, data: &PromptData) -> String {
- let mut context = String::new();
- let include_line_numbers = true;
-
- for related_file in &data.included_files {
- writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
-
- if related_file.path == data.cursor_path {
- write!(context, "<|fim_prefix|>").unwrap();
- write_excerpts(
- &related_file.excerpts,
- &[(data.cursor_point, "<|fim_suffix|>")],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- writeln!(context, "<|fim_middle|>").unwrap();
- } else {
- write_excerpts(
- &related_file.excerpts,
- &[],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- }
- }
- context
- }
-}
-
-struct SeedCoder1120Prompt;
-
-impl PromptFormatter for SeedCoder1120Prompt {
- fn render(&self, data: &PromptData) -> String {
- let edit_history = self.fmt_edit_history(data);
- let context = self.fmt_context(data);
-
- format!(
- "# Edit History:\n{edit_history}\n\n{context}",
- edit_history = edit_history,
- context = context
- )
- }
-
- fn generation_params() -> GenerationParams {
- GenerationParams {
- temperature: Some(0.2),
- top_p: Some(0.9),
- stop: Some(vec!["<[end_of_sentence]>".into()]),
- }
- }
-}
-
-impl SeedCoder1120Prompt {
- fn fmt_edit_history(&self, data: &PromptData) -> String {
- if data.events.is_empty() {
- "(No edit history)\n\n".to_string()
- } else {
- let mut events_str = String::new();
- push_events(&mut events_str, &data.events);
- events_str
- }
- }
-
- fn fmt_context(&self, data: &PromptData) -> String {
- let mut context = String::new();
- let include_line_numbers = true;
-
- for related_file in &data.included_files {
- writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
-
- if related_file.path == data.cursor_path {
- let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
- context.push_str(&fim_prompt);
- } else {
- write_excerpts(
- &related_file.excerpts,
- &[],
- related_file.max_row,
- include_line_numbers,
- &mut context,
- );
- }
- }
- context
- }
-
- fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
- let mut buf = String::new();
- const FIM_SUFFIX: &str = "<[fim-suffix]>";
- const FIM_PREFIX: &str = "<[fim-prefix]>";
- const FIM_MIDDLE: &str = "<[fim-middle]>";
- write!(buf, "{}", FIM_PREFIX).unwrap();
- write_excerpts(
- &file.excerpts,
- &[(cursor_point, FIM_SUFFIX)],
- file.max_row,
- true,
- &mut buf,
- );
-
- // Swap prefix and suffix parts
- let index = buf.find(FIM_SUFFIX).unwrap();
- let prefix = &buf[..index];
- let suffix = &buf[index..];
-
- format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
- }
-}
@@ -21,7 +21,6 @@ arrayvec.workspace = true
brotli.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
collections.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
@@ -50,8 +49,6 @@ semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
-smol.workspace = true
-strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
@@ -62,6 +59,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
+zeta_prompt.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }
@@ -1,14 +1,13 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
+use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
ZED_VERSION_HEADER_NAME,
};
-use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
AsyncReadExt as _, FutureExt as _, StreamExt as _,
- channel::{
- mpsc::{self, UnboundedReceiver},
- oneshot,
- },
+ channel::mpsc::{self, UnboundedReceiver},
select_biased,
};
use gpui::BackgroundExecutor;
@@ -58,8 +54,10 @@ mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
+
+#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
pub mod udiff;
-mod xml_edits;
+
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
-pub use crate::prediction::EditPredictionInputs;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use telemetry_events::EditPredictionRating;
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
min_bytes: 128,
target_before_cursor_over_total_bytes: 0.5,
},
- max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
prompt_format: PromptFormat::DEFAULT,
};
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
use_context: bool,
options: ZetaOptions,
update_required: bool,
- debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
Mercury,
}
+pub struct EditPredictionModelInput {
+ project: Entity<Project>,
+ buffer: Entity<Buffer>,
+ snapshot: BufferSnapshot,
+ position: Anchor,
+ events: Vec<Arc<zeta_prompt::Event>>,
+ related_files: Arc<[RelatedFile]>,
+ recent_paths: VecDeque<ProjectPath>,
+ trigger: PredictEditsRequestTrigger,
+ diagnostic_search_range: Range<Point>,
+ debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
+}
+
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
pub context: EditPredictionExcerptOptions,
- pub max_prompt_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
}
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
pub enum DebugEvent {
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
- EditPredictionRequested(EditPredictionRequestedDebugEvent),
+ EditPredictionStarted(EditPredictionStartedDebugEvent),
+ EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
#[derive(Debug)]
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
}
#[derive(Debug)]
-pub struct EditPredictionRequestedDebugEvent {
- pub inputs: EditPredictionInputs,
- pub retrieval_time: Duration,
+pub struct EditPredictionStartedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
- pub local_prompt: Result<String, String>,
- pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
+ pub prompt: Option<String>,
+}
+
+#[derive(Debug)]
+pub struct EditPredictionFinishedDebugEvent {
+ pub buffer: WeakEntity<Buffer>,
+ pub position: Anchor,
+ pub model_output: Option<String>,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ProjectState {
- events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
+ events: VecDeque<Arc<zeta_prompt::Event>>,
last_event: Option<LastEvent>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
- context_updates_tx: smol::channel::Sender<()>,
- context_updates_rx: smol::channel::Receiver<()>,
+ debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
last_prediction_refresh: Option<(EntityId, Instant)>,
cancelled_predictions: HashSet<usize>,
context: Entity<RelatedExcerptStore>,
@@ -241,7 +252,7 @@ struct ProjectState {
}
impl ProjectState {
- pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
+ pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
self.events
.iter()
.cloned()
@@ -376,7 +387,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
cx: &App,
- ) -> Option<Arc<predict_edits_v3::Event>> {
+ ) -> Option<Arc<zeta_prompt::Event>> {
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
@@ -396,7 +407,7 @@ impl LastEvent {
if path == old_path && diff.is_empty() {
None
} else {
- Some(Arc::new(predict_edits_v3::Event::BufferChange {
+ Some(Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -481,7 +492,6 @@ impl EditPredictionStore {
},
),
update_required: false,
- debug_tx: None,
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
@@ -536,12 +546,6 @@ impl EditPredictionStore {
self.eval_cache = Some(cache);
}
- pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
- let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
- self.debug_tx = Some(debug_watch_tx);
- debug_watch_rx
- }
-
pub fn options(&self) -> &ZetaOptions {
&self.options
}
@@ -560,15 +564,35 @@ impl EditPredictionStore {
}
}
+ pub fn edit_history_for_project(
+ &self,
+ project: &Entity<Project>,
+ ) -> Vec<Arc<zeta_prompt::Event>> {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project_state| project_state.events.iter().cloned().collect())
+ .unwrap_or_default()
+ }
+
pub fn context_for_project<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
- ) -> &'a [RelatedFile] {
+ ) -> Arc<[RelatedFile]> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files())
- .unwrap_or(&[])
+ .unwrap_or_else(|| vec![].into())
+ }
+
+ pub fn context_for_project_with_buffers<'a>(
+ &'a self,
+ project: &Entity<Project>,
+ cx: &'a App,
+ ) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
+ self.projects
+ .get(&project.entity_id())
+ .map(|project| project.context.read(cx).related_files_with_buffers())
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -599,85 +623,21 @@ impl EditPredictionStore {
cx: &mut Context<Self>,
) -> &mut ProjectState {
let entity_id = project.entity_id();
- let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
self.projects
.entry(entity_id)
.or_insert_with(|| ProjectState {
context: {
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
- cx.subscribe(
- &related_excerpt_store,
- move |this, _, event, _| match event {
- RelatedExcerptStoreEvent::StartedRefresh => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(DebugEvent::ContextRetrievalStarted(
- ContextRetrievalStartedDebugEvent {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- search_prompt: String::new(),
- },
- ))
- .ok();
- }
- }
- RelatedExcerptStoreEvent::FinishedRefresh {
- cache_hit_count,
- cache_miss_count,
- mean_definition_latency,
- max_definition_latency,
- } => {
- if let Some(debug_tx) = this.debug_tx.clone() {
- debug_tx
- .unbounded_send(DebugEvent::ContextRetrievalFinished(
- ContextRetrievalFinishedDebugEvent {
- project_entity_id: entity_id,
- timestamp: Instant::now(),
- metadata: vec![
- (
- "Cache Hits",
- format!(
- "{}/{}",
- cache_hit_count,
- cache_hit_count + cache_miss_count
- )
- .into(),
- ),
- (
- "Max LSP Time",
- format!(
- "{} ms",
- max_definition_latency.as_millis()
- )
- .into(),
- ),
- (
- "Mean LSP Time",
- format!(
- "{} ms",
- mean_definition_latency.as_millis()
- )
- .into(),
- ),
- ],
- },
- ))
- .ok();
- }
- if let Some(project_state) = this.projects.get(&entity_id) {
- project_state.context_updates_tx.send_blocking(()).ok();
- }
- }
- },
- )
+ cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
+ this.handle_excerpt_store_event(entity_id, event);
+ })
.detach();
related_excerpt_store
},
events: VecDeque::new(),
last_event: None,
recent_paths: VecDeque::new(),
- context_updates_rx,
- context_updates_tx,
+ debug_tx: None,
registered_buffers: HashMap::default(),
current_prediction: None,
cancelled_predictions: HashSet::default(),
@@ -689,12 +649,79 @@ impl EditPredictionStore {
})
}
- pub fn project_context_updates(
- &self,
+ pub fn remove_project(&mut self, project: &Entity<Project>) {
+ self.projects.remove(&project.entity_id());
+ }
+
+ fn handle_excerpt_store_event(
+ &mut self,
+ project_entity_id: EntityId,
+ event: &RelatedExcerptStoreEvent,
+ ) {
+ if let Some(project_state) = self.projects.get(&project_entity_id) {
+ if let Some(debug_tx) = project_state.debug_tx.clone() {
+ match event {
+ RelatedExcerptStoreEvent::StartedRefresh => {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalStarted(
+ ContextRetrievalStartedDebugEvent {
+ project_entity_id: project_entity_id,
+ timestamp: Instant::now(),
+ search_prompt: String::new(),
+ },
+ ))
+ .ok();
+ }
+ RelatedExcerptStoreEvent::FinishedRefresh {
+ cache_hit_count,
+ cache_miss_count,
+ mean_definition_latency,
+ max_definition_latency,
+ } => {
+ debug_tx
+ .unbounded_send(DebugEvent::ContextRetrievalFinished(
+ ContextRetrievalFinishedDebugEvent {
+ project_entity_id: project_entity_id,
+ timestamp: Instant::now(),
+ metadata: vec![
+ (
+ "Cache Hits",
+ format!(
+ "{}/{}",
+ cache_hit_count,
+ cache_hit_count + cache_miss_count
+ )
+ .into(),
+ ),
+ (
+ "Max LSP Time",
+ format!("{} ms", max_definition_latency.as_millis())
+ .into(),
+ ),
+ (
+ "Mean LSP Time",
+ format!("{} ms", mean_definition_latency.as_millis())
+ .into(),
+ ),
+ ],
+ },
+ ))
+ .ok();
+ }
+ }
+ }
+ }
+ }
+
+ pub fn debug_info(
+ &mut self,
project: &Entity<Project>,
- ) -> Option<smol::channel::Receiver<()>> {
- let project_state = self.projects.get(&project.entity_id())?;
- Some(project_state.context_updates_rx.clone())
+ cx: &mut Context<Self>,
+ ) -> mpsc::UnboundedReceiver<DebugEvent> {
+ let project_state = self.get_or_init_project(project, cx);
+ let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
+ project_state.debug_tx = Some(debug_watch_tx);
+ debug_watch_rx
}
fn handle_project_event(
@@ -1348,6 +1375,7 @@ impl EditPredictionStore {
let project_state = self.projects.get(&project.entity_id()).unwrap();
let events = project_state.events(cx);
let has_events = !events.is_empty();
+ let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
@@ -1357,55 +1385,29 @@ impl EditPredictionStore {
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let related_files = if self.use_context {
- self.context_for_project(&project, cx).to_vec()
+ self.context_for_project(&project, cx)
} else {
- Vec::new()
+ Vec::new().into()
+ };
+
+ let inputs = EditPredictionModelInput {
+ project: project.clone(),
+ buffer: active_buffer.clone(),
+ snapshot: snapshot.clone(),
+ position,
+ events,
+ related_files,
+ recent_paths: project_state.recent_paths.clone(),
+ trigger,
+ diagnostic_search_range: diagnostic_search_range.clone(),
+ debug_tx,
};
let task = match self.edit_prediction_model {
- EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- trigger,
- cx,
- ),
- EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
- self,
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- related_files,
- trigger,
- cx,
- ),
- EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &project_state.recent_paths,
- related_files,
- diagnostic_search_range.clone(),
- cx,
- ),
- EditPredictionModel::Mercury => self.mercury.request_prediction(
- &project,
- &active_buffer,
- snapshot.clone(),
- position,
- events,
- &project_state.recent_paths,
- related_files,
- diagnostic_search_range.clone(),
- cx,
- ),
+ EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
+ EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
+ EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
+ EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};
cx.spawn(async move |this, cx| {
@@ -1706,6 +1708,20 @@ impl EditPredictionStore {
}
}
+ #[cfg(feature = "eval-support")]
+ pub fn set_context_for_buffer(
+ &mut self,
+ project: &Entity<Project>,
+ related_files: Vec<RelatedFile>,
+ cx: &mut Context<Self>,
+ ) {
+ self.get_or_init_project(project, cx)
+ .context
+ .update(cx, |store, _| {
+ store.set_related_files(related_files);
+ });
+ }
+
fn is_file_open_source(
&self,
project: &Entity<Project>,
@@ -1729,14 +1745,14 @@ impl EditPredictionStore {
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
}
- fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
+ fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
if !self.data_collection_choice.is_enabled() {
return false;
}
events.iter().all(|event| {
matches!(
event.as_ref(),
- Event::BufferChange {
+ zeta_prompt::Event::BufferChange {
in_open_source_repo: true,
..
}
@@ -1,5 +1,5 @@
use super::*;
-use crate::zeta1::MAX_EVENT_TOKENS;
+use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -7,7 +7,6 @@ use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
};
-use edit_prediction_context::Line;
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -28,6 +27,7 @@ use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
use util::{path, rel_path::rel_path};
use uuid::Uuid;
+use zeta_prompt::ZetaPromptInput;
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
@@ -65,18 +65,21 @@ async fn test_current_state(cx: &mut TestAppContext) {
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
- .send(model_response(indoc! {r"
- --- a/root/1.txt
- +++ b/root/1.txt
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
+ .send(model_response(
+ request,
+ indoc! {r"
+ --- a/root/1.txt
+ +++ b/root/1.txt
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "},
+ ))
.unwrap();
cx.run_until_parked();
@@ -120,16 +123,20 @@ async fn test_current_state(cx: &mut TestAppContext) {
});
});
- let (_request, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/2.txt
- +++ b/root/2.txt
- Hola!
- -Como
- +Como estas?
- Adios
- "#}))
+ .send(model_response(
+ request,
+ indoc! {r#"
+ --- a/root/2.txt
+ +++ b/root/2.txt
+ @@ ... @@
+ Hola!
+ -Como
+ +Como estas?
+ Adios
+ "#},
+ ))
.unwrap();
cx.run_until_parked();
@@ -186,7 +193,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
// TODO Put back when we have a structured request again
// assert_eq!(
@@ -202,15 +209,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
// );
respond_tx
- .send(model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "}))
+ .send(model_response(
+ request,
+ indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "},
+ ))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -276,15 +286,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
);
respond_tx
- .send(model_response(indoc! {r#"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are you?
- Bye
- "#}))
+ .send(model_response(
+ request,
+ indoc! {r#"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are you?
+ Bye
+ "#},
+ ))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -324,18 +337,8 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- const NO_OP_DIFF: &str = indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How
- Bye
- "};
-
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let response = model_response(NO_OP_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let response = model_response(request, "");
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -389,13 +392,13 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text("Hello!\nHow are you?\nBye", cx);
});
- let response = model_response(SIMPLE_DIFF);
+ let response = model_response(request, SIMPLE_DIFF);
let id = response.id.clone();
respond_tx.send(response).unwrap();
@@ -459,8 +462,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -482,8 +485,8 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let second_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -541,8 +544,8 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
+ let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
@@ -564,17 +567,20 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_tx) = requests.predict.next().await.unwrap();
+ let (request, respond_tx) = requests.predict.next().await.unwrap();
// worse than current prediction
- let second_response = model_response(indoc! { r"
- --- a/root/foo.md
- +++ b/root/foo.md
- @@ ... @@
- Hello!
- -How
- +How are
- Bye
- "});
+ let second_response = model_response(
+ request,
+ indoc! { r"
+ --- a/root/foo.md
+ +++ b/root/foo.md
+ @@ ... @@
+ Hello!
+ -How
+ +How are
+ Bye
+ "},
+ );
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
@@ -633,19 +639,19 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_second) = requests.predict.next().await.unwrap();
+ let (request, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle
cx.run_until_parked();
// second responds first
- let second_response = model_response(SIMPLE_DIFF);
+ let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_second.send(second_response).unwrap();
@@ -663,7 +669,7 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});
- let first_response = model_response(SIMPLE_DIFF);
+ let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -724,13 +730,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_first) = requests.predict.next().await.unwrap();
+ let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
- let (_, respond_second) = requests.predict.next().await.unwrap();
+ let (request2, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle, so requests are sent
cx.run_until_parked();
@@ -754,9 +760,9 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
// wait for throttle
cx.run_until_parked();
- let (_, respond_third) = requests.predict.next().await.unwrap();
+ let (request3, respond_third) = requests.predict.next().await.unwrap();
- let first_response = model_response(SIMPLE_DIFF);
+ let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
@@ -774,7 +780,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let cancelled_response = model_response(SIMPLE_DIFF);
+ let cancelled_response = model_response(request2, SIMPLE_DIFF);
let cancelled_id = cancelled_response.id.clone();
respond_second.send(cancelled_response).unwrap();
@@ -792,7 +798,7 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
- let third_response = model_response(SIMPLE_DIFF);
+ let third_response = model_response(request3, SIMPLE_DIFF);
let third_response_id = third_response.id.clone();
respond_third.send(third_response).unwrap();
@@ -1036,7 +1042,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// );
// }
-fn model_response(text: &str) -> open_ai::Response {
+// Generate a model response that would apply the given diff to the active file.
+fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
+ let prompt = match &request.messages[0] {
+ open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(content),
+ } => content,
+ _ => panic!("unexpected request {request:?}"),
+ };
+
+ let open = "<editable_region>\n";
+ let close = "</editable_region>";
+ let cursor = "<|user_cursor|>";
+
+ let start_ix = open.len() + prompt.find(open).unwrap();
+ let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
+ let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
+ let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
+
open_ai::Response {
id: Uuid::new_v4().to_string(),
object: "response".into(),
@@ -1045,7 +1068,7 @@ fn model_response(text: &str) -> open_ai::Response {
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(text.to_string())),
+ content: Some(open_ai::MessageContent::Plain(new_excerpt)),
tool_calls: vec![],
},
finish_reason: None,
@@ -1160,20 +1183,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
- let completion = EditPrediction {
+ let prediction = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
- inputs: EditPredictionInputs {
+ inputs: ZetaPromptInput {
events: Default::default(),
- included_files: Default::default(),
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: Line(0),
- column: 0,
- },
+ related_files: Default::default(),
cursor_path: Path::new("").into(),
+ cursor_excerpt: "".into(),
+ editable_range_in_excerpt: 0..0,
+ cursor_offset_in_excerpt: 0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1182,7 +1204,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
cx.update(|cx| {
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1192,7 +1214,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1202,7 +1224,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1212,7 +1234,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1222,7 +1244,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1232,7 +1254,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1242,7 +1264,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1252,7 +1274,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
- &completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
+ &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1260,7 +1282,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
- assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
+ assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}
@@ -1,20 +1,17 @@
use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
- App, AppContext as _, Entity, Task,
+ App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
-use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
-use project::{Project, ProjectPath};
-use std::{
- collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
-};
+use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
+use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
+use zeta_prompt::ZetaPromptInput;
use crate::{
- EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
+ DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+ EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
@@ -38,16 +35,17 @@ impl Mercury {
store_api_token_in_keychain(api_token, cx)
}
- pub fn request_prediction(
+ pub(crate) fn request_prediction(
&self,
- _project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- _recent_paths: &VecDeque<ProjectPath>,
- related_files: Vec<RelatedFile>,
- _diagnostic_search_range: Range<Point>,
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ events,
+ related_files,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
+ let active_buffer = buffer.clone();
let result = cx.background_spawn(async move {
let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
MAX_REWRITE_TOKENS,
);
- let offset_range = editable_range.to_offset(&snapshot);
- let prompt = build_prompt(
- &events,
- &related_files,
- &snapshot,
- full_path.as_ref(),
- cursor_point,
- editable_range,
- context_range.clone(),
- );
-
- let inputs = EditPredictionInputs {
- events: events,
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(
- context_range.start.row,
- ),
- text: snapshot
- .text_for_range(context_range.clone())
- .collect::<String>()
- .into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ let context_offset_range = context_range.to_offset(&snapshot);
+
+ let editable_offset_range = editable_range.to_offset(&snapshot);
+
+ let inputs = zeta_prompt::ZetaPromptInput {
+ events,
+ related_files,
+ cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
+ - context_range.start.to_offset(&snapshot),
cursor_path: full_path.clone(),
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt: (editable_offset_range.start
+ - context_offset_range.start)
+ ..(editable_offset_range.end - context_offset_range.start),
};
+ let prompt = build_prompt(&inputs);
+
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: active_buffer.downgrade(),
+ prompt: Some(prompt.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: active_buffer.downgrade(),
+ model_output: Some(response_str.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
@@ -168,15 +179,16 @@ impl Mercury {
if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
- .text_for_range(offset_range.clone())
+ .text_for_range(editable_offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
- snapshot.anchor_after(offset_range.start + range.start)
- ..snapshot.anchor_before(offset_range.start + range.end),
+ snapshot.anchor_after(editable_offset_range.start + range.start)
+ ..snapshot
+ .anchor_before(editable_offset_range.start + range.end),
text,
)
}),
@@ -186,8 +198,6 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
- let buffer = active_buffer.clone();
-
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
}
}
-fn build_prompt(
- events: &[Arc<Event>],
- related_files: &[RelatedFile],
- cursor_buffer: &BufferSnapshot,
- cursor_buffer_path: &Path,
- cursor_point: Point,
- editable_range: Range<Point>,
- context_range: Range<Point>,
-) -> String {
+fn build_prompt(inputs: &ZetaPromptInput) -> String {
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
&mut prompt,
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|prompt| {
- for related_file in related_files {
+ for related_file in inputs.related_files.iter() {
for related_excerpt in &related_file.excerpts {
push_delimited(
prompt,
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|prompt| {
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
- prompt.push_str(related_file.path.path.as_unix_str());
+ prompt.push_str(related_file.path.to_string_lossy().as_ref());
prompt.push('\n');
prompt.push_str(&related_excerpt.text.to_string());
},
@@ -259,21 +261,22 @@ fn build_prompt(
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|prompt| {
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
- prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
+ prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
prompt.push('\n');
- let prefix_range = context_range.start..editable_range.start;
- let suffix_range = editable_range.end..context_range.end;
-
- prompt.extend(cursor_buffer.text_for_range(prefix_range));
+ prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
- let range_before_cursor = editable_range.start..cursor_point;
- let range_after_cursor = cursor_point..editable_range.end;
- prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
+ prompt.push_str(
+ &inputs.cursor_excerpt
+ [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
+ );
prompt.push_str(CURSOR_TAG);
- prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
+ prompt.push_str(
+ &inputs.cursor_excerpt
+ [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
+ );
});
- prompt.extend(cursor_buffer.text_for_range(suffix_range));
+ prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
},
);
@@ -281,8 +284,8 @@ fn build_prompt(
&mut prompt,
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|prompt| {
- for event in events {
- writeln!(prompt, "{event}").unwrap();
+ for event in inputs.events.iter() {
+ zeta_prompt::write_event(prompt, &event);
}
},
);
@@ -1,6 +1,5 @@
use std::{
ops::Range,
- path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
use edit_prediction_types::interpolate_edits;
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
-use serde::Serialize;
+use zeta_prompt::ZetaPromptInput;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -40,7 +39,7 @@ impl EditPredictionResult {
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
- inputs: EditPredictionInputs,
+ inputs: ZetaPromptInput,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -94,15 +93,7 @@ pub struct EditPrediction {
pub buffer: Entity<Buffer>,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
- pub inputs: EditPredictionInputs,
-}
-
-#[derive(Debug, Clone, Serialize)]
-pub struct EditPredictionInputs {
- pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
- pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
- pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
- pub cursor_path: Arc<Path>,
+ pub inputs: zeta_prompt::ZetaPromptInput,
}
impl EditPrediction {
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
#[cfg(test)]
mod tests {
+ use std::path::Path;
+
use super::*;
use gpui::{App, Entity, TestAppContext, prelude::*};
use language::{Buffer, ToOffset as _};
+ use zeta_prompt::ZetaPromptInput;
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -154,14 +148,13 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
- inputs: EditPredictionInputs {
+ inputs: ZetaPromptInput {
events: vec![],
- included_files: vec![],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- line: cloud_llm_client::predict_edits_v3::Line(0),
- column: 0,
- },
+ related_files: vec![].into(),
cursor_path: Path::new("path.txt").into(),
+ cursor_offset_in_excerpt: 0,
+ cursor_excerpt: "".into(),
+ editable_range_in_excerpt: 0..0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1,26 +1,21 @@
use anyhow::{Context as _, Result};
-use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
-use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
- App, AppContext as _, Entity, Task,
+ App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
-use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
+use language::{Point, ToOffset as _};
use lsp::DiagnosticSeverity;
-use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
- collections::VecDeque,
fmt::{self, Write as _},
- ops::Range,
path::Path,
sync::Arc,
time::Instant,
};
-use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
+use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
@@ -44,40 +39,34 @@ impl SweepAi {
pub fn request_prediction_with_sweep(
&self,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- recent_paths: &VecDeque<ProjectPath>,
- related_files: Vec<RelatedFile>,
- diagnostic_search_range: Range<Point>,
+ inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
- let full_path: Arc<Path> = snapshot
+ let full_path: Arc<Path> = inputs
+ .snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
- let project_file = project::File::from_dyn(snapshot.file());
+ let project_file = project::File::from_dyn(inputs.snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
- let offset = position.to_offset(&snapshot);
+ let offset = inputs.position.to_offset(&inputs.snapshot);
- let recent_buffers = recent_paths.iter().cloned();
+ let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
- let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
- if active_buffer == &buffer {
+ let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
+ if inputs.buffer == buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();
- let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
- let text = snapshot.text();
+ let text = inputs.snapshot.text();
let mut recent_changes = String::new();
- for event in &events {
+ for event in &inputs.events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}
@@ -122,20 +110,23 @@ impl SweepAi {
})
.collect::<Vec<_>>();
- let retrieval_chunks = related_files
+ let retrieval_chunks = inputs
+ .related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
- file_path: related_file.path.path.as_unix_str().to_string(),
- start_line: excerpt.point_range.start.row as usize,
- end_line: excerpt.point_range.end.row as usize,
+ file_path: related_file.path.to_string_lossy().to_string(),
+ start_line: excerpt.row_range.start as usize,
+ end_line: excerpt.row_range.end as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();
- let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
+ let diagnostic_entries = inputs
+ .snapshot
+ .diagnostics_in_range(inputs.diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@@ -195,21 +186,14 @@ impl SweepAi {
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
- let inputs = EditPredictionInputs {
- events,
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(0),
- text: request_body.file_contents.into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ let ep_inputs = zeta_prompt::ZetaPromptInput {
+ events: inputs.events,
+ related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
+ cursor_excerpt: request_body.file_contents.into(),
+ // we actually don't know
+ editable_range_in_excerpt: 0..inputs.snapshot.len(),
+ cursor_offset_in_excerpt: request_body.cursor_position,
};
let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
- let old_text = snapshot
+ let old_text = inputs
+ .snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
- snapshot.anchor_after(response.start_index + range.start)
- ..snapshot.anchor_before(response.start_index + range.end),
+ inputs
+ .snapshot
+ .anchor_after(response.start_index + range.start)
+ ..inputs
+ .snapshot
+ .anchor_before(response.start_index + range.end),
text,
)
})
@@ -254,13 +243,13 @@ impl SweepAi {
anyhow::Ok((
response.autocomplete_id,
edits,
- snapshot,
+ inputs.snapshot,
response_received_at,
- inputs,
+ ep_inputs,
))
});
- let buffer = active_buffer.clone();
+ let buffer = inputs.buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
-fn write_event(
- event: &cloud_llm_client::predict_edits_v3::Event,
- f: &mut impl fmt::Write,
-) -> fmt::Result {
+fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
match event {
- cloud_llm_client::predict_edits_v3::Event::BufferChange {
+ zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -14,68 +14,18 @@ use anyhow::anyhow;
use collections::HashMap;
use gpui::AsyncApp;
use gpui::Entity;
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
+use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
use project::Project;
-pub async fn parse_diff<'a>(
- diff_str: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
- let mut diff = DiffParser::new(diff_str);
- let mut edited_buffer = None;
- let mut edits = Vec::new();
-
- while let Some(event) = diff.next()? {
- match event {
- DiffEvent::Hunk {
- path: file_path,
- hunk,
- } => {
- let (buffer, ranges) = match edited_buffer {
- None => {
- edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
- edited_buffer
- .as_ref()
- .context("Model tried to edit a file that wasn't included")?
- }
- Some(ref current) => current,
- };
-
- edits.extend(
- resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
- .with_context(|| format!("Diff:\n{diff_str}"))?,
- );
- }
- DiffEvent::FileEnd { renamed_to } => {
- let (buffer, _) = edited_buffer
- .take()
- .context("Got a FileEnd event before an Hunk event")?;
-
- if renamed_to.is_some() {
- anyhow::bail!("edit predictions cannot rename files");
- }
-
- if diff.next()?.is_some() {
- anyhow::bail!("Edited more than one file");
- }
-
- return Ok((buffer, edits));
- }
- }
- }
-
- Err(anyhow::anyhow!("No EOF"))
-}
-
-#[derive(Debug)]
-pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
+#[derive(Clone, Debug)]
+pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
#[must_use]
-pub async fn apply_diff<'a>(
- diff_str: &'a str,
+pub async fn apply_diff(
+ diff_str: &str,
project: &Entity<Project>,
cx: &mut AsyncApp,
-) -> Result<OpenedBuffers<'a>> {
+) -> Result<OpenedBuffers> {
let mut included_files = HashMap::default();
for line in diff_str.lines() {
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
})??
.await?;
- included_files.insert(path, buffer);
+ included_files.insert(path.to_string(), buffer);
}
}
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
let (buffer, ranges) = match current_file {
None => {
let buffer = included_files
- .get_mut(&file_path)
+ .get_mut(file_path.as_ref())
.expect("Opened all files in diff");
current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
Ok(OpenedBuffers(included_files))
}
+pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
+ let mut diff = DiffParser::new(diff_str);
+
+ let mut text = text.to_string();
+
+ while let Some(event) = diff.next()? {
+ match event {
+ DiffEvent::Hunk { hunk, .. } => {
+ let hunk_offset = text
+ .find(&hunk.context)
+ .ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
+ for edit in hunk.edits.iter().rev() {
+ let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
+ text.replace_range(range, &edit.text);
+ }
+ }
+ DiffEvent::FileEnd { .. } => {}
+ }
+ }
+
+ Ok(text)
+}
+
struct PatchFile<'a> {
old_path: Cow<'a, str>,
new_path: Cow<'a, str>,
@@ -492,7 +465,6 @@ mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
- use language::Point;
use pretty_assertions::assert_eq;
use project::{FakeFs, Project};
use serde_json::json;
@@ -817,137 +789,6 @@ mod tests {
});
}
- #[gpui::test]
- async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one
- two
- three
- four
- five
- one
- two
- three
- four
- five
- "# };
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one
- two
- -three
- +3
- four
- five
- "#};
-
- let final_text = indoc! {r#"
- one
- two
- three
- four
- five
- one
- two
- 3
- four
- five
- "#};
-
- apply_diff(diff, &project, &mut cx.to_async())
- .await
- .expect_err("Non-unique edits should fail");
-
- let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
- ..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
-
- let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
- .await
- .unwrap();
-
- assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
- buffer.update(cx, |buffer, cx| {
- buffer.edit(edits, None, cx);
- assert_eq!(buffer.text(), final_text);
- });
- }
-
- #[gpui::test]
- async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one two three four
- five six seven eight
- nine ten eleven twelve
- "# };
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let diff = indoc! {r#"
- --- a/root/file1
- +++ b/root/file1
- one two three four
- -five six seven eight
- +five SIX seven eight!
- nine ten eleven twelve
- "#};
-
- let (buffer, edits) = parse_diff(diff, |_path| {
- Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
- })
- .await
- .unwrap();
-
- let edits = edits
- .into_iter()
- .map(|(range, text)| (range.to_point(&buffer), text))
- .collect::<Vec<_>>();
- assert_eq!(
- edits,
- &[
- (Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
- (Point::new(1, 20)..Point::new(1, 20), "!".into())
- ]
- );
- }
-
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
let fs = init_test(cx);
@@ -1,637 +0,0 @@
-use anyhow::{Context as _, Result};
-use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
-use std::{cmp, ops::Range, path::Path, sync::Arc};
-
-const EDITS_TAG_NAME: &'static str = "edits";
-const OLD_TEXT_TAG_NAME: &'static str = "old_text";
-const NEW_TEXT_TAG_NAME: &'static str = "new_text";
-const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
-
-pub async fn parse_xml_edits<'a>(
- input: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
- parse_xml_edits_inner(input, get_buffer)
- .await
- .with_context(|| format!("Failed to parse XML edits:\n{input}"))
-}
-
-async fn parse_xml_edits_inner<'a>(
- input: &'a str,
- get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
-) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
- let xml_edits = extract_xml_replacements(input)?;
-
- let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
- .with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
-
- let mut all_edits = vec![];
- for (old_text, new_text) in xml_edits.replacements {
- let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
- let matched_old_text = buffer
- .text_for_range(match_range.clone())
- .collect::<String>();
- let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
- all_edits.extend(
- edits_within_hunk
- .into_iter()
- .map(move |(inner_range, inner_text)| {
- (
- buffer.anchor_after(match_range.start + inner_range.start)
- ..buffer.anchor_before(match_range.start + inner_range.end),
- inner_text,
- )
- }),
- );
- }
-
- Ok((buffer, all_edits))
-}
-
-fn fuzzy_match_in_ranges(
- old_text: &str,
- buffer: &BufferSnapshot,
- context_ranges: &[Range<Anchor>],
-) -> Result<Range<usize>> {
- let mut state = FuzzyMatcher::new(buffer, old_text);
- let mut best_match = None;
- let mut tie_match_range = None;
-
- for range in context_ranges {
- let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
- match (best_match_cost, state.match_range(range.to_offset(buffer))) {
- (Some(lowest_cost), Some((new_cost, new_range))) => {
- if new_cost == lowest_cost {
- tie_match_range = Some(new_range);
- } else if new_cost < lowest_cost {
- tie_match_range.take();
- best_match = Some((new_cost, new_range));
- }
- }
- (None, Some(new_match)) => {
- best_match = Some(new_match);
- }
- (None, None) | (Some(_), None) => {}
- };
- }
-
- if let Some((_, best_match_range)) = best_match {
- if let Some(tie_match_range) = tie_match_range {
- anyhow::bail!(
- "Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
- best_match_range.clone(),
- buffer.text_for_range(best_match_range).collect::<String>(),
- tie_match_range.clone(),
- buffer.text_for_range(tie_match_range).collect::<String>()
- );
- }
- return Ok(best_match_range);
- }
-
- anyhow::bail!(
- "Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
- old_text,
- context_ranges
- .iter()
- .map(|range| buffer.text_for_range(range.clone()).collect::<String>())
- .collect::<Vec<String>>()
- .join("```\n```")
- );
-}
-
-#[derive(Debug)]
-struct XmlEdits<'a> {
- file_path: &'a str,
- /// Vec of (old_text, new_text) pairs
- replacements: Vec<(&'a str, &'a str)>,
-}
-
-fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
- let mut cursor = 0;
-
- let (edits_body_start, edits_attrs) =
- find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
-
- let file_path = edits_attrs
- .trim_start()
- .strip_prefix("path")
- .context("no path attribute on edits tag")?
- .trim_end()
- .strip_prefix('=')
- .context("no value for path attribute")?
- .trim()
- .trim_start_matches('"')
- .trim_end_matches('"');
-
- cursor = edits_body_start;
- let mut edits_list = Vec::new();
-
- while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
- let old_body_end = find_tag_close(input, &mut cursor)?;
- let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
-
- let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
- .context("no new_text tag following old_text")?;
- let new_body_end = find_tag_close(input, &mut cursor)?;
- let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
-
- edits_list.push((old_text, new_text));
- }
-
- Ok(XmlEdits {
- file_path,
- replacements: edits_list,
- })
-}
-
-/// Trims a single leading and trailing newline
-fn trim_surrounding_newlines(input: &str) -> &str {
- let start = input.strip_prefix('\n').unwrap_or(input);
- let end = start.strip_suffix('\n').unwrap_or(start);
- end
-}
-
-fn find_tag_open<'a>(
- input: &'a str,
- cursor: &mut usize,
- expected_tag: &str,
-) -> Result<Option<(usize, &'a str)>> {
- let mut search_pos = *cursor;
-
- while search_pos < input.len() {
- let Some(tag_start) = input[search_pos..].find("<") else {
- break;
- };
- let tag_start = search_pos + tag_start;
- if !input[tag_start + 1..].starts_with(expected_tag) {
- search_pos = search_pos + tag_start + 1;
- continue;
- };
-
- let after_tag_name = tag_start + expected_tag.len() + 1;
- let close_bracket = input[after_tag_name..]
- .find('>')
- .with_context(|| format!("missing > after <{}", expected_tag))?;
- let attrs_end = after_tag_name + close_bracket;
- let body_start = attrs_end + 1;
-
- let attributes = input[after_tag_name..attrs_end].trim();
- *cursor = body_start;
-
- return Ok(Some((body_start, attributes)));
- }
-
- Ok(None)
-}
-
-fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
- let mut depth = 1;
- let mut search_pos = *cursor;
-
- while search_pos < input.len() && depth > 0 {
- let Some(bracket_offset) = input[search_pos..].find('<') else {
- break;
- };
- let bracket_pos = search_pos + bracket_offset;
-
- if input[bracket_pos..].starts_with("</")
- && let Some(close_end) = input[bracket_pos + 2..].find('>')
- {
- let close_start = bracket_pos + 2;
- let tag_name = input[close_start..close_start + close_end].trim();
-
- if XML_TAGS.contains(&tag_name) {
- depth -= 1;
- if depth == 0 {
- *cursor = close_start + close_end + 1;
- return Ok(bracket_pos);
- }
- }
- search_pos = close_start + close_end + 1;
- continue;
- } else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
- let close_bracket_pos = bracket_pos + close_bracket_offset;
- let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
- if XML_TAGS.contains(&tag_name) {
- depth += 1;
- }
- }
-
- search_pos = bracket_pos + 1;
- }
-
- anyhow::bail!("no closing tag found")
-}
-
-const REPLACEMENT_COST: u32 = 1;
-const INSERTION_COST: u32 = 3;
-const DELETION_COST: u32 = 10;
-
-/// A fuzzy matcher that can process text chunks incrementally
-/// and return the best match found so far at each step.
-struct FuzzyMatcher<'a> {
- snapshot: &'a BufferSnapshot,
- query_lines: Vec<&'a str>,
- matrix: SearchMatrix,
-}
-
-impl<'a> FuzzyMatcher<'a> {
- fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
- let query_lines = old_text.lines().collect();
- Self {
- snapshot,
- query_lines,
- matrix: SearchMatrix::new(0),
- }
- }
-
- fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
- let point_range = range.to_point(&self.snapshot);
- let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
-
- self.matrix
- .reset(self.query_lines.len() + 1, buffer_line_count + 1);
- let query_line_count = self.query_lines.len();
-
- for row in 0..query_line_count {
- let query_line = self.query_lines[row].trim();
- let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
-
- self.matrix.set(
- row + 1,
- 0,
- SearchState::new(leading_deletion_cost, SearchDirection::Up),
- );
-
- let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
-
- let mut col = 0;
- while let Some(buffer_line) = buffer_lines.next() {
- let buffer_line = buffer_line.trim();
- let up = SearchState::new(
- self.matrix
- .get(row, col + 1)
- .cost
- .saturating_add(DELETION_COST),
- SearchDirection::Up,
- );
- let left = SearchState::new(
- self.matrix
- .get(row + 1, col)
- .cost
- .saturating_add(INSERTION_COST),
- SearchDirection::Left,
- );
- let diagonal = SearchState::new(
- if query_line == buffer_line {
- self.matrix.get(row, col).cost
- } else if fuzzy_eq(query_line, buffer_line) {
- self.matrix.get(row, col).cost + REPLACEMENT_COST
- } else {
- self.matrix
- .get(row, col)
- .cost
- .saturating_add(DELETION_COST + INSERTION_COST)
- },
- SearchDirection::Diagonal,
- );
- self.matrix
- .set(row + 1, col + 1, up.min(left).min(diagonal));
- col += 1;
- }
- }
-
- // Find all matches with the best cost
- let mut best_cost = u32::MAX;
- let mut matches_with_best_cost = Vec::new();
-
- for col in 1..=buffer_line_count {
- let cost = self.matrix.get(query_line_count, col).cost;
- if cost < best_cost {
- best_cost = cost;
- matches_with_best_cost.clear();
- matches_with_best_cost.push(col as u32);
- } else if cost == best_cost {
- matches_with_best_cost.push(col as u32);
- }
- }
-
- // Find ranges for the matches
- for &match_end_col in &matches_with_best_cost {
- let mut matched_lines = 0;
- let mut query_row = query_line_count;
- let mut match_start_col = match_end_col;
- while query_row > 0 && match_start_col > 0 {
- let current = self.matrix.get(query_row, match_start_col as usize);
- match current.direction {
- SearchDirection::Diagonal => {
- query_row -= 1;
- match_start_col -= 1;
- matched_lines += 1;
- }
- SearchDirection::Up => {
- query_row -= 1;
- }
- SearchDirection::Left => {
- match_start_col -= 1;
- }
- }
- }
-
- let buffer_row_start = match_start_col + point_range.start.row;
- let buffer_row_end = match_end_col + point_range.start.row;
-
- let matched_buffer_row_count = buffer_row_end - buffer_row_start;
- let matched_ratio = matched_lines as f32
- / (matched_buffer_row_count as f32).max(query_line_count as f32);
- if matched_ratio >= 0.8 {
- let buffer_start_ix = self
- .snapshot
- .point_to_offset(Point::new(buffer_row_start, 0));
- let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
- buffer_row_end - 1,
- self.snapshot.line_len(buffer_row_end - 1),
- ));
- return Some((best_cost, buffer_start_ix..buffer_end_ix));
- }
- }
-
- None
- }
-}
-
-fn fuzzy_eq(left: &str, right: &str) -> bool {
- const THRESHOLD: f64 = 0.8;
-
- let min_levenshtein = left.len().abs_diff(right.len());
- let min_normalized_levenshtein =
- 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
- if min_normalized_levenshtein < THRESHOLD {
- return false;
- }
-
- strsim::normalized_levenshtein(left, right) >= THRESHOLD
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
-enum SearchDirection {
- Up,
- Left,
- Diagonal,
-}
-
-#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct SearchState {
- cost: u32,
- direction: SearchDirection,
-}
-
-impl SearchState {
- fn new(cost: u32, direction: SearchDirection) -> Self {
- Self { cost, direction }
- }
-}
-
-struct SearchMatrix {
- cols: usize,
- rows: usize,
- data: Vec<SearchState>,
-}
-
-impl SearchMatrix {
- fn new(cols: usize) -> Self {
- SearchMatrix {
- cols,
- rows: 0,
- data: Vec::new(),
- }
- }
-
- fn reset(&mut self, rows: usize, cols: usize) {
- self.rows = rows;
- self.cols = cols;
- self.data
- .fill(SearchState::new(0, SearchDirection::Diagonal));
- self.data.resize(
- self.rows * self.cols,
- SearchState::new(0, SearchDirection::Diagonal),
- );
- }
-
- fn get(&self, row: usize, col: usize) -> SearchState {
- debug_assert!(row < self.rows);
- debug_assert!(col < self.cols);
- self.data[row * self.cols + col]
- }
-
- fn set(&mut self, row: usize, col: usize, state: SearchState) {
- debug_assert!(row < self.rows && col < self.cols);
- self.data[row * self.cols + col] = state;
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use gpui::TestAppContext;
- use indoc::indoc;
- use language::Point;
- use project::{FakeFs, Project};
- use serde_json::json;
- use settings::SettingsStore;
- use util::path;
-
- #[test]
- fn test_extract_xml_edits() {
- let input = indoc! {r#"
- <edits path="test.rs">
- <old_text>
- old content
- </old_text>
- <new_text>
- new content
- </new_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "old content");
- assert_eq!(result.replacements[0].1, "new content");
- }
-
- #[test]
- fn test_extract_xml_edits_with_wrong_closing_tags() {
- let input = indoc! {r#"
- <edits path="test.rs">
- <old_text>
- old content
- </new_text>
- <new_text>
- new content
- </old_text>
- </ edits >
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "old content");
- assert_eq!(result.replacements[0].1, "new content");
- }
-
- #[test]
- fn test_extract_xml_edits_with_xml_like_content() {
- let input = indoc! {r#"
- <edits path="component.tsx">
- <old_text>
- <foo><bar></bar></foo>
- </old_text>
- <new_text>
- <foo><bar><baz></baz></bar></foo>
- </new_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "component.tsx");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
- assert_eq!(
- result.replacements[0].1,
- "<foo><bar><baz></baz></bar></foo>"
- );
- }
-
- #[test]
- fn test_extract_xml_edits_with_conflicting_content() {
- let input = indoc! {r#"
- <edits path="component.tsx">
- <old_text>
- <new_text></new_text>
- </old_text>
- <new_text>
- <old_text></old_text>
- </new_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "component.tsx");
- assert_eq!(result.replacements.len(), 1);
- assert_eq!(result.replacements[0].0, "<new_text></new_text>");
- assert_eq!(result.replacements[0].1, "<old_text></old_text>");
- }
-
- #[test]
- fn test_extract_xml_edits_multiple_pairs() {
- let input = indoc! {r#"
- Some reasoning before edits. Lots of thinking going on here
-
- <edits path="test.rs">
- <old_text>
- first old
- </old_text>
- <new_text>
- first new
- </new_text>
- <old_text>
- second old
- </edits>
- <new_text>
- second new
- </old_text>
- </edits>
- "#};
-
- let result = extract_xml_replacements(input).unwrap();
- assert_eq!(result.file_path, "test.rs");
- assert_eq!(result.replacements.len(), 2);
- assert_eq!(result.replacements[0].0, "first old");
- assert_eq!(result.replacements[0].1, "first new");
- assert_eq!(result.replacements[1].0, "second old");
- assert_eq!(result.replacements[1].1, "second new");
- }
-
- #[test]
- fn test_extract_xml_edits_unexpected_eof() {
- let input = indoc! {r#"
- <edits path="test.rs">
- <old_text>
- first old
- </
- "#};
-
- extract_xml_replacements(input).expect_err("Unexpected end of file");
- }
-
- #[gpui::test]
- async fn test_parse_xml_edits(cx: &mut TestAppContext) {
- let fs = init_test(cx);
-
- let buffer_1_text = indoc! {r#"
- one two three four
- five six seven eight
- nine ten eleven twelve
- thirteen fourteen fifteen
- sixteen seventeen eighteen
- "#};
-
- fs.insert_tree(
- path!("/root"),
- json!({
- "file1": buffer_1_text,
- }),
- )
- .await;
-
- let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
- let buffer = project
- .update(cx, |project, cx| {
- project.open_local_buffer(path!("/root/file1"), cx)
- })
- .await
- .unwrap();
- let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
-
- let edits = indoc! {r#"
- <edits path="root/file1">
- <old_text>
- nine ten eleven twelve
- </old_text>
- <new_text>
- nine TEN eleven twelve!
- </new_text>
- </edits>
- "#};
-
- let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
- let (buffer, edits) = parse_xml_edits(edits, |_path| {
- Some((&buffer_snapshot, included_ranges.as_slice()))
- })
- .await
- .unwrap();
-
- let edits = edits
- .into_iter()
- .map(|(range, text)| (range.to_point(&buffer), text))
- .collect::<Vec<_>>();
- assert_eq!(
- edits,
- &[
- (Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
- (Point::new(2, 22)..Point::new(2, 22), "!".into())
- ]
- );
- }
-
- fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- });
-
- FakeFs::new(cx.background_executor.clone())
- }
-}
@@ -1,22 +1,23 @@
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
- EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
+ DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
+ EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
- prediction::{EditPredictionInputs, EditPredictionResult},
+ prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::{
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
- predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use language::{
- Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
+ Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
+use zeta_prompt::{Event, ZetaPromptInput};
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
store: &mut EditPredictionStore,
- project: &Entity<Project>,
- buffer: &Entity<Buffer>,
- snapshot: BufferSnapshot,
- position: language::Anchor,
- events: Vec<Arc<Event>>,
- trigger: PredictEditsRequestTrigger,
+ EditPredictionModelInput {
+ project,
+ buffer,
+ snapshot,
+ position,
+ events,
+ trigger,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
- let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
- let can_collect_file = store.can_collect_file(project, file, cx);
+ let can_collect_file = store.can_collect_file(&project, file, cx);
let git_info = if can_collect_file {
- git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
+ git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
};
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
)
.await;
- let inputs = EditPredictionInputs {
+ let context_start_offset = context_range.start.to_offset(&snapshot);
+ let editable_offset_range = editable_range.to_offset(&snapshot);
+
+ let inputs = ZetaPromptInput {
events: included_events.into(),
- included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
- path: full_path.clone(),
- max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
- excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
- start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
- text: snapshot
- .text_for_range(context_range)
- .collect::<String>()
- .into(),
- }],
- }],
- cursor_point: cloud_llm_client::predict_edits_v3::Point {
- column: cursor_point.column,
- line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
- },
+ related_files: vec![].into(),
cursor_path: full_path,
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt: (editable_range.start - context_start_offset)
+ ..(editable_offset_range.end - context_start_offset),
+ cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
};
- // let response = perform_predict_edits(PerformPredictEditsParams {
- // client,
- // llm_token,
- // app_version,
- // body,
- // })
- // .await;
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ prompt: Some(serde_json::to_string(&inputs).unwrap()),
+ position,
+ },
+ ))
+ .ok();
+ }
let (response, usage) = match response {
Ok(response) => response,
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
.ok();
}
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ model_output: Some(response.output_excerpt.clone()),
+ position,
+ },
+ ))
+ .ok();
+ }
+
let edit_prediction = process_completion_response(
response,
buffer,
@@ -226,7 +242,7 @@ fn process_completion_response(
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
editable_range: Range<usize>,
- inputs: EditPredictionInputs,
+ inputs: ZetaPromptInput,
buffer_snapshotted_at: Instant,
received_response_at: Instant,
cx: &AsyncApp,
@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
- DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
- EditPredictionRequestedDebugEvent, EditPredictionStore,
+ DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
+ EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
-use anyhow::{Result, anyhow, bail};
-use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
-use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
-use cloud_zeta2_prompt::CURSOR_MARKER;
-use edit_prediction_context::{EditPredictionExcerpt, Line};
-use edit_prediction_context::{RelatedExcerpt, RelatedFile};
-use futures::channel::oneshot;
-use gpui::{Entity, Task, prelude::*};
-use language::{Anchor, BufferSnapshot};
-use language::{Buffer, Point, ToOffset as _, ToPoint};
-use project::{Project, ProjectItem as _};
+use anyhow::{Result, anyhow};
+use cloud_llm_client::EditPredictionRejectReason;
+use gpui::{Task, prelude::*};
+use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;
-use std::{
- env,
- path::Path,
- sync::Arc,
- time::{Duration, Instant},
-};
+use std::{path::Path, sync::Arc, time::Instant};
+use zeta_prompt::CURSOR_MARKER;
+use zeta_prompt::format_zeta_prompt;
+
+const MAX_CONTEXT_TOKENS: usize = 150;
+const MAX_REWRITE_TOKENS: usize = 350;
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
- project: &Entity<Project>,
- active_buffer: &Entity<Buffer>,
- active_snapshot: BufferSnapshot,
- position: Anchor,
- events: Vec<Arc<Event>>,
- mut included_files: Vec<RelatedFile>,
- trigger: PredictEditsRequestTrigger,
+ EditPredictionModelInput {
+ buffer,
+ snapshot,
+ position,
+ related_files,
+ events,
+ debug_tx,
+ ..
+ }: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
- let options = store.options.clone();
let buffer_snapshotted_at = Instant::now();
- let Some((excerpt_path, active_project_path)) = active_snapshot
+ let Some(excerpt_path) = snapshot
.file()
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
- .zip(active_buffer.read(cx).project_path(cx))
else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
};
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
- let debug_tx = store.debug_tx.clone();
-
- let file = active_buffer.read(cx).file();
-
- let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
-
- // TODO data collection
- let can_collect_data = file
- .as_ref()
- .map_or(false, |file| store.can_collect_file(project, file, cx));
#[cfg(feature = "eval-support")]
let eval_cache = store.eval_cache.clone();
let request_task = cx.background_spawn({
- let active_buffer = active_buffer.clone();
async move {
- let cursor_offset = position.to_offset(&active_snapshot);
- let cursor_point = cursor_offset.to_point(&active_snapshot);
-
- let before_retrieval = Instant::now();
-
- let excerpt_options = options.context;
-
- let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
- cursor_point,
- &active_snapshot,
- &excerpt_options,
- ) else {
- return Ok((None, None));
- };
-
- let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
- ..active_snapshot.anchor_before(excerpt.range.end);
- let related_excerpt = RelatedExcerpt {
- anchor_range: excerpt_anchor_range.clone(),
- point_range: Point::new(excerpt.line_range.start.0, 0)
- ..Point::new(excerpt.line_range.end.0, 0),
- text: active_snapshot.as_rope().slice(excerpt.range),
- };
-
- if let Some(buffer_ix) = included_files
- .iter()
- .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
- {
- let file = &mut included_files[buffer_ix];
- file.excerpts.push(related_excerpt);
- file.merge_excerpts();
- let last_ix = included_files.len() - 1;
- included_files.swap(buffer_ix, last_ix);
- } else {
- let active_file = RelatedFile {
- path: active_project_path,
- buffer: active_buffer.downgrade(),
- excerpts: vec![related_excerpt],
- max_row: active_snapshot.max_point().row,
- };
- included_files.push(active_file);
- }
-
- let included_files = included_files
- .iter()
- .map(|related_file| predict_edits_v3::RelatedFile {
- path: Arc::from(related_file.path.path.as_std_path()),
- max_row: Line(related_file.max_row),
- excerpts: related_file
- .excerpts
- .iter()
- .map(|excerpt| predict_edits_v3::Excerpt {
- start_line: Line(excerpt.point_range.start.row),
- text: excerpt.text.to_string().into(),
- })
- .collect(),
- })
- .collect::<Vec<_>>();
-
- let cloud_request = predict_edits_v3::PredictEditsRequest {
- excerpt_path,
- excerpt: String::new(),
- excerpt_line_range: Line(0)..Line(0),
- excerpt_range: 0..0,
- cursor_point: predict_edits_v3::Point {
- line: predict_edits_v3::Line(cursor_point.row),
- column: cursor_point.column,
- },
- related_files: included_files,
+ let cursor_offset = position.to_offset(&snapshot);
+ let (editable_offset_range, prompt_input) = zeta2_prompt_input(
+ &snapshot,
+ related_files,
events,
- can_collect_data,
- debug_info: debug_tx.is_some(),
- prompt_max_bytes: Some(options.max_prompt_bytes),
- prompt_format: options.prompt_format,
- excerpt_parent: None,
- git_info: None,
- trigger,
- };
-
- let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
-
- let inputs = EditPredictionInputs {
- included_files: cloud_request.related_files,
- events: cloud_request.events,
- cursor_point: cloud_request.cursor_point,
- cursor_path: cloud_request.excerpt_path,
- };
-
- let retrieval_time = Instant::now() - before_retrieval;
+ excerpt_path,
+ cursor_offset,
+ );
- let debug_response_tx = if let Some(debug_tx) = &debug_tx {
- let (response_tx, response_rx) = oneshot::channel();
+ let prompt = format_zeta_prompt(&prompt_input);
+ if let Some(debug_tx) = &debug_tx {
debug_tx
- .unbounded_send(DebugEvent::EditPredictionRequested(
- EditPredictionRequestedDebugEvent {
- inputs: inputs.clone(),
- retrieval_time,
- buffer: active_buffer.downgrade(),
- local_prompt: match prompt_result.as_ref() {
- Ok(prompt) => Ok(prompt.clone()),
- Err(err) => Err(err.to_string()),
- },
+ .unbounded_send(DebugEvent::EditPredictionStarted(
+ EditPredictionStartedDebugEvent {
+ buffer: buffer.downgrade(),
+ prompt: Some(prompt.clone()),
position,
- response_rx,
},
))
.ok();
- Some(response_tx)
- } else {
- None
- };
-
- if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((Err("Request skipped".to_string()), Duration::ZERO))
- .ok();
- }
- anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
- let prompt = prompt_result?;
- let generation_params =
- cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
let request = open_ai::Request {
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
}],
stream: false,
max_completion_tokens: None,
- stop: generation_params.stop.unwrap_or_default(),
- temperature: generation_params.temperature.or(Some(0.7)),
+ stop: Default::default(),
+ temperature: Default::default(),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
log::trace!("Sending edit prediction request");
- let before_request = Instant::now();
let response = EditPredictionStore::send_raw_llm_request(
request,
client,
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
)
.await;
let received_response_at = Instant::now();
- let request_time = received_response_at - before_request;
log::trace!("Got edit prediction response");
- if let Some(debug_response_tx) = debug_response_tx {
- debug_response_tx
- .send((
- response
- .as_ref()
- .map_err(|err| err.to_string())
- .map(|response| response.0.clone()),
- request_time,
- ))
- .ok();
- }
-
let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(mut output_text) = text_from_response(res) else {
return Ok((Some((request_id, None)), usage));
};
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(DebugEvent::EditPredictionFinished(
+ EditPredictionFinishedDebugEvent {
+ buffer: buffer.downgrade(),
+ position,
+ model_output: Some(output_text.clone()),
+ },
+ ))
+ .ok();
+ }
+
if output_text.contains(CURSOR_MARKER) {
log::trace!("Stripping out {CURSOR_MARKER} from response");
output_text = output_text.replace(CURSOR_MARKER, "");
}
- let get_buffer_from_context = |path: &Path| {
- if Some(path) == active_file_full_path.as_deref() {
- Some((
- &active_snapshot,
- std::slice::from_ref(&excerpt_anchor_range),
- ))
- } else {
- None
- }
- };
-
- let (_, edits) = match options.prompt_format {
- PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
- if output_text.contains("--- a/\n+++ b/\nNo edits") {
- let edits = vec![];
- (&active_snapshot, edits)
- } else {
- crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
- }
- }
- PromptFormat::OldTextNewText => {
- crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
- }
- _ => {
- bail!("unsupported prompt format {}", options.prompt_format)
- }
- };
+ let old_text = snapshot
+ .text_for_range(editable_offset_range.clone())
+ .collect::<String>();
+ let edits: Vec<_> = language::text_diff(&old_text, &output_text)
+ .into_iter()
+ .map(|(range, text)| {
+ (
+ snapshot.anchor_after(editable_offset_range.start + range.start)
+ ..snapshot.anchor_before(editable_offset_range.start + range.end),
+ text,
+ )
+ })
+ .collect();
anyhow::Ok((
Some((
request_id,
Some((
- inputs,
- active_buffer,
- active_snapshot.clone(),
+ prompt_input,
+ buffer,
+ snapshot.clone(),
edits,
received_response_at,
)),
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
))
})
}
+
+pub fn zeta2_prompt_input(
+ snapshot: &language::BufferSnapshot,
+ related_files: Arc<[zeta_prompt::RelatedFile]>,
+ events: Vec<Arc<zeta_prompt::Event>>,
+ excerpt_path: Arc<Path>,
+ cursor_offset: usize,
+) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
+ let cursor_point = cursor_offset.to_point(snapshot);
+
+ let (editable_range, context_range) =
+ crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
+ cursor_point,
+ snapshot,
+ MAX_CONTEXT_TOKENS,
+ MAX_REWRITE_TOKENS,
+ );
+
+ let context_start_offset = context_range.start.to_offset(snapshot);
+ let editable_offset_range = editable_range.to_offset(snapshot);
+ let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
+ let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
+ ..(editable_offset_range.end - context_start_offset);
+
+ let prompt_input = zeta_prompt::ZetaPromptInput {
+ cursor_path: excerpt_path,
+ cursor_excerpt: snapshot
+ .text_for_range(context_range)
+ .collect::<String>()
+ .into(),
+ editable_range_in_excerpt,
+ cursor_offset_in_excerpt,
+ events,
+ related_files,
+ };
+ (editable_offset_range, prompt_input)
+}
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
-name = "ep_cli"
+name = "ep"
path = "src/main.rs"
[dependencies]
@@ -20,10 +20,9 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
-cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
-edit_prediction_context.workspace = true
+dirs.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -51,12 +50,21 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
terminal_view.workspace = true
-toml.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["eval-support"] }
+wasmtime.workspace = true
+zeta_prompt.workspace = true
zlog.workspace = true
+# Wasmtime is included as a dependency in order to enable the same
+# features that are enabled in Zed.
+#
+# If we don't enable these features we get crashes when creating
+# a Tree-sitter WasmStore.
+[package.metadata.cargo-machete]
+ignored = ["wasmtime"]
+
[dev-dependencies]
indoc.workspace = true
gpui = { workspace = true, features = ["test-support"] }
@@ -5,11 +5,13 @@ use anthropic::{
use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
+use reqwest_client::ReqwestClient;
use sqlez::bindable::Bind;
use sqlez::bindable::StaticColumnCount;
use sqlez_macros::sql;
use std::hash::Hash;
use std::hash::Hasher;
+use std::path::Path;
use std::sync::Arc;
pub struct PlainLlmClient {
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
}
impl PlainLlmClient {
- fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ fn new() -> Result<Self> {
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
Ok(Self {
@@ -29,12 +32,12 @@ impl PlainLlmClient {
async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<AnthropicResponse> {
let request = AnthropicRequest {
- model,
+ model: model.to_string(),
max_tokens,
messages,
tools: Vec::new(),
@@ -105,11 +108,12 @@ struct SerializableMessage {
}
impl BatchingLlmClient {
- fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
+ fn new(cache_path: &Path) -> Result<Self> {
+ let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
- let connection = sqlez::connection::Connection::open_file(&cache_path);
+ let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
let mut statement = sqlez::statement::Statement::prepare(
&connection,
indoc! {"
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
- let response = self.lookup(&model, max_tokens, &messages)?;
+ let response = self.lookup(model, max_tokens, &messages)?;
if let Some(response) = response {
return Ok(Some(response));
}
- self.mark_for_batch(&model, max_tokens, &messages)?;
+ self.mark_for_batch(model, max_tokens, &messages)?;
Ok(None)
}
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
}
}
}
- log::info!("Uploaded {} successful requests", success_count);
+ log::info!("Downloaded {} successful requests", success_count);
}
}
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
.join("\n")
}
-pub enum LlmClient {
+pub enum AnthropicClient {
// No batching
Plain(PlainLlmClient),
Batch(BatchingLlmClient),
Dummy,
}
-impl LlmClient {
- pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
- Ok(Self::Plain(PlainLlmClient::new(http_client)?))
+impl AnthropicClient {
+ pub fn plain() -> Result<Self> {
+ Ok(Self::Plain(PlainLlmClient::new()?))
}
- pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
- Ok(Self::Batch(BatchingLlmClient::new(
- cache_path,
- http_client,
- )?))
+ pub fn batch(cache_path: &Path) -> Result<Self> {
+ Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
}
#[allow(dead_code)]
@@ -389,29 +390,29 @@ impl LlmClient {
pub async fn generate(
&self,
- model: String,
+ model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
match self {
- LlmClient::Plain(plain_llm_client) => plain_llm_client
+ AnthropicClient::Plain(plain_llm_client) => plain_llm_client
.generate(model, max_tokens, messages)
.await
.map(Some),
- LlmClient::Batch(batching_llm_client) => {
+ AnthropicClient::Batch(batching_llm_client) => {
batching_llm_client
.generate(model, max_tokens, messages)
.await
}
- LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+ AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
pub async fn sync_batches(&self) -> Result<()> {
match self {
- LlmClient::Plain(_) => Ok(()),
- LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
- LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+ AnthropicClient::Plain(_) => Ok(()),
+ AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
+ AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
}
@@ -1,641 +0,0 @@
-use crate::metrics::{self, Scores};
-use std::{
- collections::HashMap,
- io::{IsTerminal, Write},
- sync::Arc,
-};
-
-use anyhow::Result;
-use edit_prediction::{EditPredictionStore, udiff::DiffLine};
-use gpui::{AsyncApp, Entity};
-use project::Project;
-use util::ResultExt as _;
-
-use crate::{
- EvaluateArguments, PredictionOptions,
- example::{Example, NamedExample},
- headless::ZetaCliAppState,
- paths::print_run_data_dir,
- predict::{PredictionDetails, perform_predict, setup_store},
-};
-
-#[derive(Debug)]
-pub(crate) struct ExecutionData {
- execution_id: String,
- diff: String,
- reasoning: String,
-}
-
-pub async fn run_evaluate(
- args: EvaluateArguments,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) {
- if args.example_paths.is_empty() {
- eprintln!("No examples provided");
- return;
- }
-
- let all_tasks = args.example_paths.into_iter().map(|path| {
- let options = args.options.clone();
- let app_state = app_state.clone();
- let example = NamedExample::load(&path).expect("Failed to load example");
-
- cx.spawn(async move |cx| {
- let project = example.setup_project(&app_state, cx).await.unwrap();
-
- let providers = (0..args.repetitions)
- .map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
- .collect::<Vec<_>>();
-
- let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
-
- let tasks = providers
- .into_iter()
- .enumerate()
- .map(move |(repetition_ix, store)| {
- let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
- let example = example.clone();
- let project = project.clone();
- let options = options.clone();
-
- cx.spawn(async move |cx| {
- let name = example.name.clone();
- run_evaluate_one(
- example,
- repetition_ix,
- project,
- store,
- options,
- !args.skip_prediction,
- cx,
- )
- .await
- .map_err(|err| (err, name, repetition_ix))
- })
- });
- futures::future::join_all(tasks).await
- })
- });
- let all_results = futures::future::join_all(all_tasks).await;
-
- write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
- if let Some(mut output_file) =
- std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
- {
- write_aggregated_scores(&mut output_file, &all_results).log_err();
- };
-
- if args.repetitions > 1 {
- if let Err(e) = write_bucketed_analysis(&all_results) {
- eprintln!("Failed to write bucketed analysis: {:?}", e);
- }
- }
-
- print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
-}
-
-fn write_aggregated_scores(
- w: &mut impl std::io::Write,
- all_results: &Vec<
- Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
- >,
-) -> Result<()> {
- let mut successful = Vec::new();
- let mut failed_count = 0;
-
- for result in all_results.iter().flatten() {
- match result {
- Ok((eval_result, _execution_data)) => successful.push(eval_result),
- Err((err, name, repetition_ix)) => {
- if failed_count == 0 {
- writeln!(w, "## Errors\n")?;
- }
-
- failed_count += 1;
- writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
- }
- }
- }
-
- if successful.len() > 1 {
- let edit_scores = successful
- .iter()
- .filter_map(|r| r.edit_scores.clone())
- .collect::<Vec<_>>();
- let has_edit_predictions = edit_scores.len() > 0;
- let aggregated_result = EvaluationResult {
- context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
- edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
- prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
- generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
- / successful.len(),
- };
-
- writeln!(w, "\n{}", "-".repeat(80))?;
- writeln!(w, "\n## TOTAL SCORES")?;
- writeln!(w, "{:#}", aggregated_result)?;
- }
-
- if successful.len() + failed_count > 1 {
- writeln!(
- w,
- "\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
- successful.len(),
- successful.len() + failed_count,
- (successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
- )?;
- }
-
- Ok(())
-}
-
-pub async fn run_evaluate_one(
- example: NamedExample,
- repetition_ix: Option<u16>,
- project: Entity<Project>,
- store: Entity<EditPredictionStore>,
- prediction_options: PredictionOptions,
- predict: bool,
- cx: &mut AsyncApp,
-) -> Result<(EvaluationResult, ExecutionData)> {
- let predict_result = perform_predict(
- example.clone(),
- project,
- store,
- repetition_ix,
- prediction_options,
- cx,
- )
- .await?;
-
- let evaluation_result = evaluate(&example.example, &predict_result, predict);
-
- if repetition_ix.is_none() {
- write_eval_result(
- &example,
- &predict_result,
- &evaluation_result,
- &mut std::io::stdout(),
- std::io::stdout().is_terminal(),
- predict,
- )?;
- }
-
- if let Some(mut results_file) =
- std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
- {
- write_eval_result(
- &example,
- &predict_result,
- &evaluation_result,
- &mut results_file,
- false,
- predict,
- )
- .log_err();
- }
-
- let execution_data = ExecutionData {
- execution_id: if let Some(rep_ix) = repetition_ix {
- format!("{:03}", rep_ix)
- } else {
- example.name.clone()
- },
- diff: predict_result.diff.clone(),
- reasoning: std::fs::read_to_string(
- predict_result
- .run_example_dir
- .join("prediction_response.md"),
- )
- .unwrap_or_default(),
- };
-
- anyhow::Ok((evaluation_result, execution_data))
-}
-
-fn write_eval_result(
- example: &NamedExample,
- predictions: &PredictionDetails,
- evaluation_result: &EvaluationResult,
- out: &mut impl Write,
- use_color: bool,
- predict: bool,
-) -> Result<()> {
- if predict {
- writeln!(
- out,
- "## Expected edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(
- &example.example.expected_patch,
- &predictions.diff,
- use_color
- )
- )?;
- writeln!(
- out,
- "## Actual edit prediction:\n\n```diff\n{}\n```\n",
- compare_diffs(
- &predictions.diff,
- &example.example.expected_patch,
- use_color
- )
- )?;
- }
-
- writeln!(out, "{:#}", evaluation_result)?;
-
- anyhow::Ok(())
-}
-
-#[derive(Debug, Default, Clone)]
-pub struct EditScores {
- pub line_match: Scores,
- pub chr_f: f64,
-}
-
-impl EditScores {
- pub fn aggregate(scores: &[EditScores]) -> EditScores {
- let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
- let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
-
- EditScores { line_match, chr_f }
- }
-}
-
-#[derive(Debug, Default)]
-pub struct EvaluationResult {
- pub edit_scores: Option<EditScores>,
- pub context_scores: Scores,
- pub prompt_len: usize,
- pub generated_len: usize,
-}
-
-impl std::fmt::Display for EvaluationResult {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- if f.alternate() {
- self.fmt_table(f)
- } else {
- self.fmt_markdown(f)
- }
- }
-}
-
-impl EvaluationResult {
- fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(
- f,
- r#"
-### Context Scores
-{}
-"#,
- self.context_scores.to_markdown(),
- )?;
- if let Some(scores) = &self.edit_scores {
- write!(
- f,
- r#"
- ### Edit Prediction Scores
- {}"#,
- scores.line_match.to_markdown()
- )?;
- }
- Ok(())
- }
-
- fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- writeln!(f, "#### Prompt Statistics")?;
- writeln!(f, "─────────────────────────")?;
- writeln!(f, "Prompt_len Generated_len")?;
- writeln!(f, "─────────────────────────")?;
- writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
- writeln!(f)?;
- writeln!(f)?;
- writeln!(f, "#### Performance Scores")?;
- writeln!(
- f,
- "──────────────────────────────────────────────────────────────────"
- )?;
- writeln!(
- f,
- " TP FP FN Precision Recall F1"
- )?;
- writeln!(
- f,
- "──────────────────────────────────────────────────────────────────"
- )?;
- writeln!(
- f,
- "Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
- self.context_scores.true_positives,
- self.context_scores.false_positives,
- self.context_scores.false_negatives,
- self.context_scores.precision() * 100.0,
- self.context_scores.recall() * 100.0,
- self.context_scores.f1_score() * 100.0
- )?;
- if let Some(edit_scores) = &self.edit_scores {
- let line_match = &edit_scores.line_match;
- writeln!(f, "Edit Prediction")?;
- writeln!(
- f,
- " ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
- line_match.true_positives,
- line_match.false_positives,
- line_match.false_negatives,
- line_match.precision() * 100.0,
- line_match.recall() * 100.0,
- line_match.f1_score() * 100.0
- )?;
- writeln!(
- f,
- " └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
- "-", "-", "-", "-", "-", edit_scores.chr_f
- )?;
- }
- Ok(())
- }
-}
-
-fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
- let mut eval_result = EvaluationResult {
- prompt_len: preds.prompt_len,
- generated_len: preds.generated_len,
- ..Default::default()
- };
-
- if predict {
- // todo: alternatives for patches
- let expected_patch = example
- .expected_patch
- .lines()
- .map(DiffLine::parse)
- .collect::<Vec<_>>();
- let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
-
- let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
- let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
-
- eval_result.edit_scores = Some(EditScores { line_match, chr_f });
- }
-
- eval_result
-}
-
-/// Return annotated `patch_a` so that:
-/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
-/// Additions and deletions that are present in `patch_b` will be highlighted in green.
-pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
- let green = if use_color { "\x1b[32m✓ " } else { "" };
- let red = if use_color { "\x1b[31m✗ " } else { "" };
- let neutral = if use_color { " " } else { "" };
- let reset = if use_color { "\x1b[0m" } else { "" };
- let lines_a = patch_a.lines().map(DiffLine::parse);
- let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
-
- let annotated = lines_a
- .map(|line| match line {
- DiffLine::Addition(_) | DiffLine::Deletion(_) => {
- if lines_b.contains(&line) {
- format!("{green}{line}{reset}")
- } else {
- format!("{red}{line}{reset}")
- }
- }
- _ => format!("{neutral}{line}{reset}"),
- })
- .collect::<Vec<String>>();
-
- annotated.join("\n")
-}
-
-fn write_bucketed_analysis(
- all_results: &Vec<
- Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
- >,
-) -> Result<()> {
- #[derive(Debug)]
- struct EditBucket {
- diff: String,
- is_correct: bool,
- execution_indices: Vec<String>,
- reasoning_samples: Vec<String>,
- }
-
- let mut total_executions = 0;
- let mut empty_predictions = Vec::new();
- let mut errors = Vec::new();
-
- let mut buckets: HashMap<String, EditBucket> = HashMap::new();
-
- for result in all_results.iter().flatten() {
- total_executions += 1;
-
- let (evaluation_result, execution_data) = match result {
- Ok((eval_result, execution_data)) => {
- if execution_data.diff.is_empty() {
- empty_predictions.push(execution_data);
- continue;
- }
- (eval_result, execution_data)
- }
- Err(err) => {
- errors.push(err);
- continue;
- }
- };
-
- buckets
- .entry(execution_data.diff.clone())
- .and_modify(|bucket| {
- bucket
- .execution_indices
- .push(execution_data.execution_id.clone());
- bucket
- .reasoning_samples
- .push(execution_data.reasoning.clone());
- })
- .or_insert_with(|| EditBucket {
- diff: execution_data.diff.clone(),
- is_correct: {
- evaluation_result
- .edit_scores
- .as_ref()
- .map_or(false, |edit_scores| {
- edit_scores.line_match.false_positives == 0
- && edit_scores.line_match.false_negatives == 0
- && edit_scores.line_match.true_positives > 0
- })
- },
- execution_indices: vec![execution_data.execution_id.clone()],
- reasoning_samples: vec![execution_data.reasoning.clone()],
- });
- }
-
- let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
- sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
- (true, false) => std::cmp::Ordering::Less,
- (false, true) => std::cmp::Ordering::Greater,
- _ => b.execution_indices.len().cmp(&a.execution_indices.len()),
- });
-
- let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
- let mut output = std::fs::File::create(&output_path)?;
-
- writeln!(output, "# Bucketed Edit Analysis\n")?;
-
- writeln!(output, "## Summary\n")?;
- writeln!(output, "- **Total executions**: {}", total_executions)?;
-
- let correct_count: usize = sorted_buckets
- .iter()
- .filter(|b| b.is_correct)
- .map(|b| b.execution_indices.len())
- .sum();
-
- let incorrect_count: usize = sorted_buckets
- .iter()
- .filter(|b| !b.is_correct)
- .map(|b| b.execution_indices.len())
- .sum();
-
- writeln!(
- output,
- "- **Correct predictions**: {} ({:.1}%)",
- correct_count,
- (correct_count as f64 / total_executions as f64) * 100.0
- )?;
-
- writeln!(
- output,
- "- **Incorrect predictions**: {} ({:.1}%)",
- incorrect_count,
- (incorrect_count as f64 / total_executions as f64) * 100.0
- )?;
-
- writeln!(
- output,
- "- **No Predictions**: {} ({:.1}%)",
- empty_predictions.len(),
- (empty_predictions.len() as f64 / total_executions as f64) * 100.0
- )?;
-
- let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
- writeln!(
- output,
- "- **Unique incorrect edit patterns**: {}\n",
- unique_incorrect
- )?;
-
- writeln!(output, "---\n")?;
-
- for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
- if idx == 0 {
- writeln!(
- output,
- "## Correct Predictions ({} occurrences)\n",
- bucket.execution_indices.len()
- )?;
- }
-
- writeln!(output, "**Predicted Edit:**\n")?;
- writeln!(output, "```diff")?;
- writeln!(output, "{}", bucket.diff)?;
- writeln!(output, "```\n")?;
-
- writeln!(
- output,
- "**Executions:** {}\n",
- bucket.execution_indices.join(", ")
- )?;
- writeln!(output, "---\n")?;
- }
-
- for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
- writeln!(
- output,
- "## Incorrect Prediction #{} ({} occurrences)\n",
- idx + 1,
- bucket.execution_indices.len()
- )?;
-
- writeln!(output, "**Predicted Edit:**\n")?;
- writeln!(output, "```diff")?;
- writeln!(output, "{}", bucket.diff)?;
- writeln!(output, "```\n")?;
-
- writeln!(
- output,
- "**Executions:** {}\n",
- bucket.execution_indices.join(", ")
- )?;
-
- for (exec_id, reasoning) in bucket
- .execution_indices
- .iter()
- .zip(bucket.reasoning_samples.iter())
- {
- writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
- }
-
- writeln!(output, "\n---\n")?;
- }
-
- if !empty_predictions.is_empty() {
- writeln!(
- output,
- "## No Predictions ({} occurrences)\n",
- empty_predictions.len()
- )?;
-
- for execution_data in &empty_predictions {
- writeln!(
- output,
- "{}",
- fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
- )?;
- }
- writeln!(output, "\n---\n")?;
- }
-
- if !errors.is_empty() {
- writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
-
- for (err, name, repetition_ix) in &errors {
- writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
- }
- writeln!(output, "\n---\n")?;
- }
-
- fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
- let exec_content = format!(
- "\n### Execution {} `{}/{}/prediction_response.md`{}",
- exec_id,
- crate::paths::RUN_DIR.display(),
- exec_id,
- indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
- );
- indent_text(&exec_content, 2)
- }
-
- fn indent_text(text: &str, spaces: usize) -> String {
- let indent = " ".repeat(spaces);
- text.lines()
- .collect::<Vec<_>>()
- .join(&format!("\n{}", indent))
- }
-
- Ok(())
-}
-
-fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
- let err = format!("{err:?}")
- .replace("<edits", "```xml\n<edits")
- .replace("</edits>", "</edits>\n```");
- format!(
- "### ERROR {name}{}\n\n{err}\n",
- repetition_ix
- .map(|ix| format!(" [RUN {ix:03}]"))
- .unwrap_or_default()
- )
-}
@@ -1,59 +1,103 @@
+use crate::{
+ PredictionProvider, PromptFormat,
+ metrics::ClassificationMetrics,
+ paths::{REPOS_DIR, WORKTREES_DIR},
+};
+use anyhow::{Context as _, Result};
+use edit_prediction::udiff::OpenedBuffers;
+use gpui::Entity;
+use http_client::Url;
+use language::{Anchor, Buffer};
+use project::Project;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
use std::{
borrow::Cow,
- cell::RefCell,
- fmt::{self, Display},
- fs,
- hash::Hash,
- hash::Hasher,
- io::Write,
+ io::{Read, Write},
mem,
path::{Path, PathBuf},
- sync::{Arc, OnceLock},
};
+use zeta_prompt::RelatedFile;
-use crate::headless::ZetaCliAppState;
-use anyhow::{Context as _, Result, anyhow};
-use clap::ValueEnum;
-use cloud_zeta2_prompt::CURSOR_MARKER;
-use collections::HashMap;
-use edit_prediction::udiff::OpenedBuffers;
-use futures::{
- AsyncWriteExt as _,
- lock::{Mutex, OwnedMutexGuard},
-};
-use futures::{FutureExt as _, future::Shared};
-use gpui::{AsyncApp, Entity, Task, http_client::Url};
-use language::{Anchor, Buffer};
-use project::{Project, ProjectPath};
-use pulldown_cmark::CowStr;
-use serde::{Deserialize, Serialize};
-use util::{paths::PathStyle, rel_path::RelPath};
-
-use crate::paths::{REPOS_DIR, WORKTREES_DIR};
-
-const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
-const EDIT_HISTORY_HEADING: &str = "Edit History";
-const CURSOR_POSITION_HEADING: &str = "Cursor Position";
-const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
-const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
-const REPOSITORY_URL_FIELD: &str = "repository_url";
-const REVISION_FIELD: &str = "revision";
-
-#[derive(Debug, Clone)]
-pub struct NamedExample {
- pub name: String,
- pub example: Example,
-}
-
-#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
+ #[serde(default)]
+ pub name: String,
pub repository_url: String,
pub revision: String,
pub uncommitted_diff: String,
- pub cursor_path: PathBuf,
+ pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
+
+ /// The full content of the file where an edit is being predicted, and the
+ /// actual cursor offset.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub buffer: Option<ExampleBuffer>,
+
+ /// The context retrieved for the prediction. This requires the worktree to
+ /// be loaded and the language server to be started.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub context: Option<ExampleContext>,
+
+ /// The input and expected output from the edit prediction model.
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub prompt: Option<ExamplePrompt>,
+
+ /// The actual predictions from the model.
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub predictions: Vec<ExamplePrediction>,
+
+ /// The scores, for how well the actual predictions match the expected
+ /// predictions.
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
+ pub score: Vec<ExampleScore>,
+
+ /// The application state used to process this example.
+ #[serde(skip)]
+ pub state: Option<ExampleState>,
+}
+
+#[derive(Clone, Debug)]
+pub struct ExampleState {
+ pub project: Entity<Project>,
+ pub buffer: Entity<Buffer>,
+ pub cursor_position: Anchor,
+ pub _open_buffers: OpenedBuffers,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleContext {
+ pub files: Arc<[RelatedFile]>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleBuffer {
+ pub content: String,
+ pub cursor_row: u32,
+ pub cursor_column: u32,
+ pub cursor_offset: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExamplePrompt {
+ pub input: String,
+ pub expected_output: String,
+ pub format: PromptFormat,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExamplePrediction {
+ pub actual_patch: String,
+ pub actual_output: String,
+ pub provider: PredictionProvider,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ExampleScore {
+ pub delta_chr_f: f32,
+ pub line_match: ClassificationMetrics,
}
impl Example {
@@ -90,485 +134,244 @@ impl Example {
}
}
- pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
- let (repo_owner, repo_name) = self.repo_name()?;
-
- let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
- let repo_lock = lock_repo(&repo_dir).await;
+ pub fn worktree_path(&self) -> PathBuf {
+ WORKTREES_DIR
+ .join(&self.name)
+ .join(self.repo_name().unwrap().1.as_ref())
+ }
- if !repo_dir.is_dir() {
- fs::create_dir_all(&repo_dir)?;
- run_git(&repo_dir, &["init"]).await?;
- run_git(
- &repo_dir,
- &["remote", "add", "origin", &self.repository_url],
- )
- .await?;
- }
+ pub fn repo_path(&self) -> PathBuf {
+ let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
+ REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
+ }
+}
- // Resolve the example to a revision, fetching it if needed.
- let revision = run_git(
- &repo_dir,
- &["rev-parse", &format!("{}^{{commit}}", self.revision)],
- )
- .await;
- let revision = if let Ok(revision) = revision {
- revision
+pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
+ let mut examples = Vec::new();
+
+ let stdin_path: PathBuf = PathBuf::from("-");
+
+ let inputs = if inputs.is_empty() {
+ &[stdin_path]
+ } else {
+ inputs
+ };
+
+ for path in inputs {
+ let is_stdin = path.as_path() == Path::new("-");
+ let content = if is_stdin {
+ let mut buffer = String::new();
+ std::io::stdin()
+ .read_to_string(&mut buffer)
+ .expect("Failed to read from stdin");
+ buffer
} else {
- if run_git(
- &repo_dir,
- &["fetch", "--depth", "1", "origin", &self.revision],
- )
- .await
- .is_err()
- {
- run_git(&repo_dir, &["fetch", "origin"]).await?;
- }
- let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
- if revision != self.revision {
- run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
- }
- revision
+ std::fs::read_to_string(path)
+ .unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
};
-
- // Create the worktree for this example if needed.
- let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
- if worktree_path.is_dir() {
- run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
- run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
- run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
+ let filename = path.file_stem().unwrap().to_string_lossy().to_string();
+ let ext = if !is_stdin {
+ path.extension()
+ .map(|ext| ext.to_string_lossy().to_string())
+ .unwrap_or_else(|| panic!("{} should have an extension", path.display()))
} else {
- let worktree_path_string = worktree_path.to_string_lossy();
- run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
- run_git(
- &repo_dir,
- &["worktree", "add", "-f", &worktree_path_string, &file_name],
- )
- .await?;
- }
- drop(repo_lock);
-
- // Apply the uncommitted diff for this example.
- if !self.uncommitted_diff.is_empty() {
- let mut apply_process = smol::process::Command::new("git")
- .current_dir(&worktree_path)
- .args(&["apply", "-"])
- .stdin(std::process::Stdio::piped())
- .spawn()?;
-
- let mut stdin = apply_process.stdin.take().unwrap();
- stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
- stdin.close().await?;
- drop(stdin);
-
- let apply_result = apply_process.output().await?;
- if !apply_result.status.success() {
- anyhow::bail!(
- "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
- apply_result.status,
- String::from_utf8_lossy(&apply_result.stderr),
- String::from_utf8_lossy(&apply_result.stdout),
- );
+ "jsonl".to_string()
+ };
+
+ match ext.as_ref() {
+ "json" => {
+ let mut example =
+ serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
+ panic!("Failed to parse example file: {}\n{error}", path.display())
+ });
+ if example.name.is_empty() {
+ example.name = filename;
+ }
+ examples.push(example);
+ }
+ "jsonl" => examples.extend(
+ content
+ .lines()
+ .enumerate()
+ .map(|(line_ix, line)| {
+ let mut example =
+ serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
+ panic!(
+ "Failed to parse example on {}:{}",
+ path.display(),
+ line_ix + 1
+ )
+ });
+ if example.name.is_empty() {
+ example.name = format!("{filename}-{line_ix}")
+ }
+ example
+ })
+ .collect::<Vec<Example>>(),
+ ),
+ "md" => {
+ examples.push(parse_markdown_example(filename, &content).unwrap());
+ }
+ ext => {
+ panic!("{} has invalid example extension `{ext}`", path.display())
}
}
-
- Ok(worktree_path)
- }
-
- pub fn unique_name(&self) -> String {
- let mut hasher = std::hash::DefaultHasher::new();
- self.hash(&mut hasher);
- let disambiguator = hasher.finish();
- let hash = format!("{:04x}", disambiguator);
- format!("{}_{}", &self.revision[..8], &hash[..4])
}
+ examples
}
-pub type ActualExcerpt = Excerpt;
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct Excerpt {
- pub path: PathBuf,
- pub text: String,
-}
-
-#[derive(ValueEnum, Debug, Clone)]
-pub enum ExampleFormat {
- Json,
- Toml,
- Md,
+pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
+ let mut content = String::new();
+ for example in examples {
+ let line = serde_json::to_string(example).unwrap();
+ content.push_str(&line);
+ content.push('\n');
+ }
+ if let Some(output_path) = output_path {
+ std::fs::write(output_path, content).expect("Failed to write examples");
+ } else {
+ std::io::stdout().write_all(&content.as_bytes()).unwrap();
+ }
}
-impl NamedExample {
- pub fn load(path: impl AsRef<Path>) -> Result<Self> {
- let path = path.as_ref();
- let content = std::fs::read_to_string(path)?;
- let ext = path.extension();
-
- match ext.and_then(|s| s.to_str()) {
- Some("json") => Ok(Self {
- name: path.file_stem().unwrap_or_default().display().to_string(),
- example: serde_json::from_str(&content)?,
- }),
- Some("toml") => Ok(Self {
- name: path.file_stem().unwrap_or_default().display().to_string(),
- example: toml::from_str(&content)?,
- }),
- Some("md") => Self::parse_md(&content),
- Some(_) => {
- anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
- }
- None => {
- anyhow::bail!(
- "Failed to determine example type since the file does not have an extension."
- );
- }
- }
+fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
+ use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
+
+ const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
+ const EDIT_HISTORY_HEADING: &str = "Edit History";
+ const CURSOR_POSITION_HEADING: &str = "Cursor Position";
+ const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
+ const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
+ const REPOSITORY_URL_FIELD: &str = "repository_url";
+ const REVISION_FIELD: &str = "revision";
+
+ let parser = Parser::new(input);
+
+ let mut example = Example {
+ name: id,
+ repository_url: String::new(),
+ revision: String::new(),
+ uncommitted_diff: String::new(),
+ cursor_path: PathBuf::new().into(),
+ cursor_position: String::new(),
+ edit_history: String::new(),
+ expected_patch: String::new(),
+ buffer: None,
+ context: None,
+ prompt: None,
+ predictions: Vec::new(),
+ score: Vec::new(),
+ state: None,
+ };
+
+ let mut name = String::new();
+ let mut text = String::new();
+ let mut block_info: CowStr = "".into();
+
+ #[derive(PartialEq)]
+ enum Section {
+ UncommittedDiff,
+ EditHistory,
+ CursorPosition,
+ ExpectedExcerpts,
+ ExpectedPatch,
+ Other,
}
- pub fn parse_md(input: &str) -> Result<Self> {
- use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
-
- let parser = Parser::new(input);
-
- let mut named = NamedExample {
- name: String::new(),
- example: Example {
- repository_url: String::new(),
- revision: String::new(),
- uncommitted_diff: String::new(),
- cursor_path: PathBuf::new(),
- cursor_position: String::new(),
- edit_history: String::new(),
- expected_patch: String::new(),
- },
- };
+ let mut current_section = Section::Other;
- let mut text = String::new();
- let mut block_info: CowStr = "".into();
-
- #[derive(PartialEq)]
- enum Section {
- UncommittedDiff,
- EditHistory,
- CursorPosition,
- ExpectedExcerpts,
- ExpectedPatch,
- Other,
- }
+ for event in parser {
+ match event {
+ Event::Text(line) => {
+ text.push_str(&line);
- let mut current_section = Section::Other;
-
- for event in parser {
- match event {
- Event::Text(line) => {
- text.push_str(&line);
-
- if !named.name.is_empty()
- && current_section == Section::Other
- // in h1 section
- && let Some((field, value)) = line.split_once('=')
- {
- match field.trim() {
- REPOSITORY_URL_FIELD => {
- named.example.repository_url = value.trim().to_string();
- }
- REVISION_FIELD => {
- named.example.revision = value.trim().to_string();
- }
- _ => {}
- }
- }
- }
- Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
- if !named.name.is_empty() {
- anyhow::bail!(
- "Found multiple H1 headings. There should only be one with the name of the example."
- );
- }
- named.name = mem::take(&mut text);
- }
- Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
- let title = mem::take(&mut text);
- current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
- Section::UncommittedDiff
- } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
- Section::EditHistory
- } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
- Section::CursorPosition
- } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
- Section::ExpectedPatch
- } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
- Section::ExpectedExcerpts
- } else {
- Section::Other
- };
- }
- Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
- mem::take(&mut text);
- }
- Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
- mem::take(&mut text);
- }
- Event::End(TagEnd::Heading(level)) => {
- anyhow::bail!("Unexpected heading level: {level}");
- }
- Event::Start(Tag::CodeBlock(kind)) => {
- match kind {
- CodeBlockKind::Fenced(info) => {
- block_info = info;
- }
- CodeBlockKind::Indented => {
- anyhow::bail!("Unexpected indented codeblock");
- }
- };
- }
- Event::Start(_) => {
- text.clear();
- block_info = "".into();
- }
- Event::End(TagEnd::CodeBlock) => {
- let block_info = block_info.trim();
- match current_section {
- Section::UncommittedDiff => {
- named.example.uncommitted_diff = mem::take(&mut text);
- }
- Section::EditHistory => {
- named.example.edit_history.push_str(&mem::take(&mut text));
- }
- Section::CursorPosition => {
- named.example.cursor_path = block_info.into();
- named.example.cursor_position = mem::take(&mut text);
- }
- Section::ExpectedExcerpts => {
- mem::take(&mut text);
+ if let Some((field, value)) = line.split_once('=') {
+ match field.trim() {
+ REPOSITORY_URL_FIELD => {
+ example.repository_url = value.trim().to_string();
}
- Section::ExpectedPatch => {
- named.example.expected_patch = mem::take(&mut text);
+ REVISION_FIELD => {
+ example.revision = value.trim().to_string();
}
- Section::Other => {}
+ _ => {}
}
}
- _ => {}
}
- }
-
- if named.example.cursor_path.as_path() == Path::new("")
- || named.example.cursor_position.is_empty()
- {
- anyhow::bail!("Missing cursor position codeblock");
- }
-
- Ok(named)
- }
-
- pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
- match format {
- ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
- ExampleFormat::Toml => {
- Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
+ Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
+ if !name.is_empty() {
+ anyhow::bail!(
+ "Found multiple H1 headings. There should only be one with the name of the example."
+ );
+ }
+ name = mem::take(&mut text);
}
- ExampleFormat::Md => Ok(write!(out, "{}", self)?),
- }
- }
-
- pub async fn setup_project(
- &self,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
- ) -> Result<Entity<Project>> {
- let worktree_path = self.setup_worktree().await?;
-
- static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
-
- AUTHENTICATED
- .get_or_init(|| {
- let client = app_state.client.clone();
- cx.spawn(async move |cx| {
- client
- .sign_in_with_optional_connect(true, cx)
- .await
- .unwrap();
- })
- .shared()
- })
- .clone()
- .await;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
-
- anyhow::Ok(project)
- }
-
- pub async fn setup_worktree(&self) -> Result<PathBuf> {
- self.example.setup_worktree(self.file_name()).await
- }
-
- pub fn file_name(&self) -> String {
- self.name
- .chars()
- .map(|c| {
- if c.is_whitespace() {
- '-'
+ Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
+ let title = mem::take(&mut text);
+ current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
+ Section::UncommittedDiff
+ } else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
+ Section::EditHistory
+ } else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
+ Section::CursorPosition
+ } else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
+ Section::ExpectedPatch
+ } else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
+ Section::ExpectedExcerpts
} else {
- c.to_ascii_lowercase()
+ Section::Other
+ };
+ }
+ Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
+ mem::take(&mut text);
+ }
+ Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
+ mem::take(&mut text);
+ }
+ Event::End(TagEnd::Heading(level)) => {
+ anyhow::bail!("Unexpected heading level: {level}");
+ }
+ Event::Start(Tag::CodeBlock(kind)) => {
+ match kind {
+ CodeBlockKind::Fenced(info) => {
+ block_info = info;
+ }
+ CodeBlockKind::Indented => {
+ anyhow::bail!("Unexpected indented codeblock");
+ }
+ };
+ }
+ Event::Start(_) => {
+ text.clear();
+ block_info = "".into();
+ }
+ Event::End(TagEnd::CodeBlock) => {
+ let block_info = block_info.trim();
+ match current_section {
+ Section::UncommittedDiff => {
+ example.uncommitted_diff = mem::take(&mut text);
+ }
+ Section::EditHistory => {
+ example.edit_history.push_str(&mem::take(&mut text));
+ }
+ Section::CursorPosition => {
+ example.cursor_path = Path::new(block_info).into();
+ example.cursor_position = mem::take(&mut text);
+ }
+ Section::ExpectedExcerpts => {
+ mem::take(&mut text);
+ }
+ Section::ExpectedPatch => {
+ example.expected_patch = mem::take(&mut text);
+ }
+ Section::Other => {}
}
- })
- .collect()
- }
-
- pub async fn cursor_position(
- &self,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<(Entity<Buffer>, Anchor)> {
- let worktree = project.read_with(cx, |project, cx| {
- project.visible_worktrees(cx).next().unwrap()
- })?;
- let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
- let cursor_buffer = project
- .update(cx, |project, cx| {
- project.open_buffer(
- ProjectPath {
- worktree_id: worktree.read(cx).id(),
- path: cursor_path,
- },
- cx,
- )
- })?
- .await?;
- let cursor_offset_within_excerpt = self
- .example
- .cursor_position
- .find(CURSOR_MARKER)
- .ok_or_else(|| anyhow!("missing cursor marker"))?;
- let mut cursor_excerpt = self.example.cursor_position.clone();
- cursor_excerpt.replace_range(
- cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
- "",
- );
- let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
- let text = buffer.text();
-
- let mut matches = text.match_indices(&cursor_excerpt);
- let Some((excerpt_offset, _)) = matches.next() else {
- anyhow::bail!(
- "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
- );
- };
- assert!(matches.next().is_none());
-
- Ok(excerpt_offset)
- })??;
-
- let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
- let cursor_anchor =
- cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
- Ok((cursor_buffer, cursor_anchor))
- }
-
- #[must_use]
- pub async fn apply_edit_history(
- &self,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<OpenedBuffers<'_>> {
- edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
- }
-}
-
-async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
- let output = smol::process::Command::new("git")
- .current_dir(repo_path)
- .args(args)
- .output()
- .await?;
-
- anyhow::ensure!(
- output.status.success(),
- "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
- args.join(" "),
- repo_path.display(),
- output.status,
- String::from_utf8_lossy(&output.stderr),
- String::from_utf8_lossy(&output.stdout),
- );
- Ok(String::from_utf8(output.stdout)?.trim().to_string())
-}
-
-impl Display for NamedExample {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "# {}\n\n", self.name)?;
- write!(
- f,
- "{REPOSITORY_URL_FIELD} = {}\n",
- self.example.repository_url
- )?;
- write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
-
- write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
- write!(f, "`````diff\n")?;
- write!(f, "{}", self.example.uncommitted_diff)?;
- write!(f, "`````\n")?;
-
- if !self.example.edit_history.is_empty() {
- write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
- }
-
- write!(
- f,
- "## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
- self.example.cursor_path.display(),
- self.example.cursor_position
- )?;
- write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
-
- if !self.example.expected_patch.is_empty() {
- write!(
- f,
- "\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
- self.example.expected_patch
- )?;
+ }
+ _ => {}
}
-
- Ok(())
}
-}
-
-thread_local! {
- static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
-}
+ if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
+ anyhow::bail!("Missing cursor position codeblock");
+ }
-#[must_use]
-pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
- REPO_LOCKS
- .with(|cell| {
- cell.borrow_mut()
- .entry(path.as_ref().to_path_buf())
- .or_default()
- .clone()
- })
- .lock_owned()
- .await
+ Ok(example)
}
@@ -0,0 +1,280 @@
+use crate::{
+ PromptFormat,
+ example::{Example, ExamplePrompt},
+ headless::EpAppState,
+ retrieve_context::run_context_retrieval,
+};
+use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
+use gpui::AsyncApp;
+use std::sync::Arc;
+use zeta_prompt::format_zeta_prompt;
+
+pub async fn run_format_prompt(
+ example: &mut Example,
+ prompt_format: PromptFormat,
+ app_state: Arc<EpAppState>,
+ mut cx: AsyncApp,
+) {
+ run_context_retrieval(example, app_state, cx.clone()).await;
+
+ let prompt = match prompt_format {
+ PromptFormat::Teacher => TeacherPrompt::format(example),
+ PromptFormat::Zeta2 => {
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
+
+ let state = example.state.as_ref().unwrap();
+ let snapshot = state
+ .buffer
+ .read_with(&cx, |buffer, _| buffer.snapshot())
+ .unwrap();
+ let project = state.project.clone();
+ let (_, input) = ep_store
+ .update(&mut cx, |ep_store, _cx| {
+ zeta2_prompt_input(
+ &snapshot,
+ example.context.as_ref().unwrap().files.clone(),
+ ep_store.edit_history_for_project(&project),
+ example.cursor_path.clone(),
+ example.buffer.as_ref().unwrap().cursor_offset,
+ )
+ })
+ .unwrap();
+ format_zeta_prompt(&input)
+ }
+ };
+
+ example.prompt = Some(ExamplePrompt {
+ input: prompt,
+ expected_output: example.expected_patch.clone(), // TODO
+ format: prompt_format,
+ });
+}
+
+pub trait PromptFormatter {
+ fn format(example: &Example) -> String;
+}
+
+pub trait PromptParser {
+ /// Return unified diff patch of prediction given raw LLM response
+ fn parse(example: &Example, response: &str) -> String;
+}
+
+pub struct TeacherPrompt;
+
+impl PromptFormatter for TeacherPrompt {
+ fn format(example: &Example) -> String {
+ let edit_history = Self::format_edit_history(&example.edit_history);
+ let context = Self::format_context(example);
+ let editable_region = Self::format_editable_region(example);
+
+ let prompt = Self::PROMPT
+ .replace("{{context}}", &context)
+ .replace("{{edit_history}}", &edit_history)
+ .replace("{{editable_region}}", &editable_region);
+
+ prompt
+ }
+}
+
+impl TeacherPrompt {
+ const PROMPT: &str = include_str!("teacher.prompt.md");
+ pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
+ pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
+
+ /// Truncate edit history to this number of last lines
+ const MAX_HISTORY_LINES: usize = 128;
+
+ fn format_edit_history(edit_history: &str) -> String {
+ // Strip comments ("garbage lines") from edit history
+ let lines = edit_history
+ .lines()
+ .filter(|&s| Self::is_udiff_content_line(s))
+ .collect::<Vec<_>>();
+
+ let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
+ &lines[lines.len() - Self::MAX_HISTORY_LINES..]
+ } else {
+ &lines
+ };
+
+ if history_lines.is_empty() {
+ return "(No edit history)".to_string();
+ }
+
+ history_lines.join("\n")
+ }
+
+ fn format_context(example: &Example) -> String {
+ if example.context.is_none() {
+ panic!("Missing context retriever step");
+ }
+
+ let mut prompt = String::new();
+ zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
+
+ prompt
+ }
+
+ fn format_editable_region(example: &Example) -> String {
+ let mut result = String::new();
+
+ let path_str = example.cursor_path.to_string_lossy();
+ result.push_str(&format!("`````path=\"{path_str}\"\n"));
+ result.push_str(Self::EDITABLE_REGION_START);
+
+ // TODO: control number of lines around cursor
+ result.push_str(&example.cursor_position);
+ if !example.cursor_position.ends_with('\n') {
+ result.push('\n');
+ }
+
+ result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
+ result.push_str("`````");
+
+ result
+ }
+
+ fn extract_editable_region(text: &str) -> String {
+ let start = text
+ .find(Self::EDITABLE_REGION_START)
+ .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
+ let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
+
+ let region = &text[start..end];
+
+ region.replace("<|user_cursor|>", "")
+ }
+
+ fn is_udiff_content_line(s: &str) -> bool {
+ s.starts_with("-")
+ || s.starts_with("+")
+ || s.starts_with(" ")
+ || s.starts_with("---")
+ || s.starts_with("+++")
+ || s.starts_with("@@")
+ }
+}
+
+impl PromptParser for TeacherPrompt {
+ fn parse(example: &Example, response: &str) -> String {
+ // Ideally, we should always be able to find cursor position in the retrieved context.
+ // In reality, sometimes we don't find it for these reasons:
+ // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
+ // (can be fixed by getting cursor coordinates at the load_example stage)
+ // 2. Context retriever just didn't include cursor line.
+ //
+ // In that case, fallback to using `cursor_position` as excerpt.
+ let cursor_file = &example
+ .buffer
+ .as_ref()
+ .expect("`buffer` should be filled in in the context collection step")
+ .content;
+
+ // Extract updated (new) editable region from the model response
+ let new_editable_region = extract_last_codeblock(response);
+
+ // Reconstruct old editable region we sent to the model
+ let old_editable_region = Self::format_editable_region(example);
+ let old_editable_region = Self::extract_editable_region(&old_editable_region);
+ if !cursor_file.contains(&old_editable_region) {
+ panic!("Something's wrong: editable_region is not found in the cursor file")
+ }
+
+ // Apply editable region to a larger context and compute diff.
+ // This is needed to get a better context lines around the editable region
+ let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
+ let diff = language::unified_diff(&cursor_file, &edited_file);
+
+ let diff = indoc::formatdoc! {"
+ --- a/{path}
+ +++ b/{path}
+ {diff}
+ ",
+ path = example.cursor_path.to_string_lossy(),
+ diff = diff,
+ };
+
+ diff
+ }
+}
+
+fn extract_last_codeblock(text: &str) -> String {
+ let mut last_block = None;
+ let mut search_start = 0;
+
+ while let Some(start) = text[search_start..].find("```") {
+ let start = start + search_start;
+ let bytes = text.as_bytes();
+ let mut backtick_end = start;
+
+ while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
+ backtick_end += 1;
+ }
+
+ let backtick_count = backtick_end - start;
+ let closing_backticks = "`".repeat(backtick_count);
+
+ while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
+ backtick_end += 1;
+ }
+
+ if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
+ let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
+ last_block = Some(code_block.to_string());
+ search_start = backtick_end + end_pos + backtick_count;
+ } else {
+ break;
+ }
+ }
+
+ last_block.unwrap_or_else(|| text.to_string())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_extract_last_code_block() {
+ let text = indoc::indoc! {"
+ Some thinking
+
+ ```
+ first block
+ ```
+
+ `````path='something' lines=1:2
+ last block
+ `````
+ "};
+ let last_block = extract_last_codeblock(text);
+ assert_eq!(last_block, "last block");
+ }
+
+ #[test]
+ fn test_extract_editable_region() {
+ let text = indoc::indoc! {"
+ some lines
+ are
+ here
+ <|editable_region_start|>
+ one
+ two three
+
+ <|editable_region_end|>
+ more
+ lines here
+ "};
+ let parsed = TeacherPrompt::extract_editable_region(text);
+ assert_eq!(
+ parsed,
+ indoc::indoc! {"
+ one
+ two three
+
+ "}
+ );
+ }
+}
@@ -16,7 +16,7 @@ use std::sync::Arc;
use util::ResultExt as _;
/// Headless subset of `workspace::AppState`.
-pub struct ZetaCliAppState {
+pub struct EpAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
@@ -25,7 +25,7 @@ pub struct ZetaCliAppState {
}
// TODO: dedupe with crates/eval/src/eval.rs
-pub fn init(cx: &mut App) -> ZetaCliAppState {
+pub fn init(cx: &mut App) -> EpAppState {
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
let app_version = AppVersion::load(
@@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
prompt_store::init(cx);
terminal_view::init(cx);
- ZetaCliAppState {
+ EpAppState {
languages,
client,
user_store,
@@ -0,0 +1,320 @@
+use crate::{
+ example::{Example, ExampleBuffer, ExampleState},
+ headless::EpAppState,
+};
+use anyhow::{Result, anyhow};
+use collections::HashMap;
+use edit_prediction::EditPredictionStore;
+use edit_prediction::udiff::OpenedBuffers;
+use futures::{
+ AsyncWriteExt as _,
+ lock::{Mutex, OwnedMutexGuard},
+};
+use gpui::{AsyncApp, Entity};
+use language::{Anchor, Buffer, ToOffset, ToPoint};
+use project::buffer_store::BufferStoreEvent;
+use project::{Project, ProjectPath};
+use std::{
+ cell::RefCell,
+ fs,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+use util::{paths::PathStyle, rel_path::RelPath};
+use zeta_prompt::CURSOR_MARKER;
+
+pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
+ if example.state.is_some() {
+ return;
+ }
+
+ let project = setup_project(example, &app_state, &mut cx).await;
+ let buffer_store = project
+ .read_with(&cx, |project, _| project.buffer_store().clone())
+ .unwrap();
+
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
+
+ cx.subscribe(&buffer_store, {
+ let project = project.clone();
+ move |_, event, cx| match event {
+ BufferStoreEvent::BufferAdded(buffer) => {
+ ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
+ }
+ _ => {}
+ }
+ })
+ .unwrap()
+ .detach();
+
+ let _open_buffers = apply_edit_history(example, &project, &mut cx)
+ .await
+ .unwrap();
+ let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
+ example.buffer = buffer
+ .read_with(&cx, |buffer, _cx| {
+ let cursor_point = cursor_position.to_point(&buffer);
+ Some(ExampleBuffer {
+ content: buffer.text(),
+ cursor_row: cursor_point.row,
+ cursor_column: cursor_point.column,
+ cursor_offset: cursor_position.to_offset(&buffer),
+ })
+ })
+ .unwrap();
+ example.state = Some(ExampleState {
+ buffer,
+ project,
+ cursor_position,
+ _open_buffers,
+ });
+}
+
+async fn cursor_position(
+ example: &Example,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+) -> (Entity<Buffer>, Anchor) {
+ let worktree = project
+ .read_with(cx, |project, cx| {
+ project.visible_worktrees(cx).next().unwrap()
+ })
+ .unwrap();
+
+ let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
+ .unwrap()
+ .into_arc();
+ let cursor_buffer = project
+ .update(cx, |project, cx| {
+ project.open_buffer(
+ ProjectPath {
+ worktree_id: worktree.read(cx).id(),
+ path: cursor_path,
+ },
+ cx,
+ )
+ })
+ .unwrap()
+ .await
+ .unwrap();
+ let cursor_offset_within_excerpt = example
+ .cursor_position
+ .find(CURSOR_MARKER)
+ .ok_or_else(|| anyhow!("missing cursor marker"))
+ .unwrap();
+ let mut cursor_excerpt = example.cursor_position.clone();
+ cursor_excerpt.replace_range(
+ cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
+ "",
+ );
+ let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
+ let text = buffer.text();
+
+ let mut matches = text.match_indices(&cursor_excerpt);
+ let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
+ panic!(
+ "\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
+ );
+ });
+ assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
+ excerpt_offset
+ }).unwrap();
+
+ let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
+ let cursor_anchor = cursor_buffer
+ .read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
+ .unwrap();
+
+ (cursor_buffer, cursor_anchor)
+}
+
+async fn setup_project(
+ example: &mut Example,
+ app_state: &Arc<EpAppState>,
+ cx: &mut AsyncApp,
+) -> Entity<Project> {
+ setup_worktree(example).await;
+
+ let project = cx
+ .update(|cx| {
+ Project::local(
+ app_state.client.clone(),
+ app_state.node_runtime.clone(),
+ app_state.user_store.clone(),
+ app_state.languages.clone(),
+ app_state.fs.clone(),
+ None,
+ cx,
+ )
+ })
+ .unwrap();
+
+ let worktree = project
+ .update(cx, |project, cx| {
+ project.create_worktree(&example.worktree_path(), true, cx)
+ })
+ .unwrap()
+ .await
+ .unwrap();
+ worktree
+ .read_with(cx, |worktree, _cx| {
+ worktree.as_local().unwrap().scan_complete()
+ })
+ .unwrap()
+ .await;
+ project
+}
+
+pub async fn setup_worktree(example: &Example) {
+ let repo_dir = example.repo_path();
+ let repo_lock = lock_repo(&repo_dir).await;
+
+ if !repo_dir.is_dir() {
+ fs::create_dir_all(&repo_dir).unwrap();
+ run_git(&repo_dir, &["init"]).await.unwrap();
+ run_git(
+ &repo_dir,
+ &["remote", "add", "origin", &example.repository_url],
+ )
+ .await
+ .unwrap();
+ }
+
+ // Resolve the example to a revision, fetching it if needed.
+ let revision = run_git(
+ &repo_dir,
+ &["rev-parse", &format!("{}^{{commit}}", example.revision)],
+ )
+ .await;
+ let revision = if let Ok(revision) = revision {
+ revision
+ } else {
+ if run_git(
+ &repo_dir,
+ &["fetch", "--depth", "1", "origin", &example.revision],
+ )
+ .await
+ .is_err()
+ {
+ run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
+ }
+ let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
+ .await
+ .unwrap();
+ if revision != example.revision {
+ run_git(&repo_dir, &["tag", &example.revision, &revision])
+ .await
+ .unwrap();
+ }
+ revision
+ };
+
+ // Create the worktree for this example if needed.
+ let worktree_path = example.worktree_path();
+ if worktree_path.is_dir() {
+ run_git(&worktree_path, &["clean", "--force", "-d"])
+ .await
+ .unwrap();
+ run_git(&worktree_path, &["reset", "--hard", "HEAD"])
+ .await
+ .unwrap();
+ run_git(&worktree_path, &["checkout", revision.as_str()])
+ .await
+ .unwrap();
+ } else {
+ let worktree_path_string = worktree_path.to_string_lossy();
+ run_git(
+ &repo_dir,
+ &["branch", "-f", &example.name, revision.as_str()],
+ )
+ .await
+ .unwrap();
+ run_git(
+ &repo_dir,
+ &[
+ "worktree",
+ "add",
+ "-f",
+ &worktree_path_string,
+ &example.name,
+ ],
+ )
+ .await
+ .unwrap();
+ }
+ drop(repo_lock);
+
+ // Apply the uncommitted diff for this example.
+ if !example.uncommitted_diff.is_empty() {
+ let mut apply_process = smol::process::Command::new("git")
+ .current_dir(&worktree_path)
+ .args(&["apply", "-"])
+ .stdin(std::process::Stdio::piped())
+ .spawn()
+ .unwrap();
+
+ let mut stdin = apply_process.stdin.take().unwrap();
+ stdin
+ .write_all(example.uncommitted_diff.as_bytes())
+ .await
+ .unwrap();
+ stdin.close().await.unwrap();
+ drop(stdin);
+
+ let apply_result = apply_process.output().await.unwrap();
+ if !apply_result.status.success() {
+ panic!(
+ "Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
+ apply_result.status,
+ String::from_utf8_lossy(&apply_result.stderr),
+ String::from_utf8_lossy(&apply_result.stdout),
+ );
+ }
+ }
+}
+
+async fn apply_edit_history(
+ example: &Example,
+ project: &Entity<Project>,
+ cx: &mut AsyncApp,
+) -> Result<OpenedBuffers> {
+ edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
+}
+
+thread_local! {
+ static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
+}
+
+#[must_use]
+pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
+ REPO_LOCKS
+ .with(|cell| {
+ cell.borrow_mut()
+ .entry(path.as_ref().to_path_buf())
+ .or_default()
+ .clone()
+ })
+ .lock_owned()
+ .await
+}
+
+async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
+ let output = smol::process::Command::new("git")
+ .current_dir(repo_path)
+ .args(args)
+ .output()
+ .await?;
+
+ anyhow::ensure!(
+ output.status.success(),
+ "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
+ args.join(" "),
+ repo_path.display(),
+ output.status,
+ String::from_utf8_lossy(&output.stderr),
+ String::from_utf8_lossy(&output.stdout),
+ );
+ Ok(String::from_utf8(output.stdout)?.trim().to_string())
+}
@@ -1,522 +1,196 @@
-mod evaluate;
+mod anthropic_client;
mod example;
+mod format_prompt;
mod headless;
+mod load_project;
mod metrics;
mod paths;
mod predict;
-mod source_location;
-mod training;
-mod util;
+mod retrieve_context;
+mod score;
-use crate::{
- evaluate::run_evaluate,
- example::{ExampleFormat, NamedExample},
- headless::ZetaCliAppState,
- predict::run_predict,
- source_location::SourceLocation,
- training::{context::ContextType, distill::run_distill},
- util::{open_buffer, open_buffer_with_language_server},
-};
-use ::util::{ResultExt, paths::PathStyle};
-use anyhow::{Result, anyhow};
-use clap::{Args, Parser, Subcommand, ValueEnum};
-use cloud_llm_client::predict_edits_v3;
-use edit_prediction::udiff::DiffLine;
-use edit_prediction_context::EditPredictionExcerptOptions;
-use gpui::{Application, AsyncApp, Entity, prelude::*};
-use language::{Bias, Buffer, BufferSnapshot, Point};
-use metrics::delta_chr_f;
-use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
+use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
+use edit_prediction::EditPredictionStore;
+use gpui::Application;
use reqwest_client::ReqwestClient;
-use std::io::{self};
-use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
+use serde::{Deserialize, Serialize};
+use std::{path::PathBuf, sync::Arc};
+
+use crate::example::{read_examples, write_examples};
+use crate::format_prompt::run_format_prompt;
+use crate::load_project::run_load_project;
+use crate::predict::run_prediction;
+use crate::retrieve_context::run_context_retrieval;
+use crate::score::run_scoring;
#[derive(Parser, Debug)]
-#[command(name = "zeta")]
-struct ZetaCliArgs {
+#[command(name = "ep")]
+struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
+ #[clap(long, default_value_t = 10)]
+ max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
+ #[clap(global = true)]
+ inputs: Vec<PathBuf>,
+ #[arg(long, short, global = true)]
+ output: Option<PathBuf>,
+ #[arg(long, short, global = true)]
+ in_place: bool,
}
#[derive(Subcommand, Debug)]
enum Command {
- Context(ContextArgs),
- Predict(PredictArguments),
- Eval(EvaluateArguments),
- Distill(DistillArguments),
- ConvertExample {
- path: PathBuf,
- #[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
- output_format: ExampleFormat,
- },
- Score {
- golden_patch: PathBuf,
- actual_patch: PathBuf,
- },
+ /// Parse markdown examples and output a combined .jsonl file
+ ParseExample,
+ /// Create git worktrees for each example and load file contents
+ LoadBuffer,
+ /// Retrieve context for input examples.
+ Context,
+ /// Generate a prompt string for a specific model
+ FormatPrompt(FormatPromptArgs),
+ /// Runs edit prediction
+ Predict(PredictArgs),
+ /// Computes a score based on actual and expected patches
+ Score(PredictArgs),
+ /// Print aggregated scores
+ Eval(PredictArgs),
+ /// Remove git repositories and worktrees
Clean,
}
#[derive(Debug, Args)]
-struct ContextArgs {
- #[arg(long)]
- provider: ContextProvider,
- #[arg(long)]
- worktree: PathBuf,
- #[arg(long)]
- cursor: SourceLocation,
- #[arg(long)]
- use_language_server: bool,
- #[arg(long)]
- edit_history: Option<FileOrStdin>,
- #[clap(flatten)]
- zeta2_args: Zeta2Args,
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
-enum ContextProvider {
- Zeta1,
- #[default]
- Zeta2,
-}
-
-#[derive(Clone, Debug, Args)]
-struct Zeta2Args {
- #[arg(long, default_value_t = 8192)]
- max_prompt_bytes: usize,
- #[arg(long, default_value_t = 2048)]
- max_excerpt_bytes: usize,
- #[arg(long, default_value_t = 1024)]
- min_excerpt_bytes: usize,
- #[arg(long, default_value_t = 0.66)]
- target_before_cursor_over_total_bytes: f32,
- #[arg(long, default_value_t = 1024)]
- max_diagnostic_bytes: usize,
- #[arg(long, value_enum, default_value_t = PromptFormat::default())]
+struct FormatPromptArgs {
+ #[clap(long)]
prompt_format: PromptFormat,
- #[arg(long, value_enum, default_value_t = Default::default())]
- output_format: OutputFormat,
- #[arg(long, default_value_t = 42)]
- file_indexing_parallelism: usize,
- #[arg(long, default_value_t = false)]
- disable_imports_gathering: bool,
- #[arg(long, default_value_t = u8::MAX)]
- max_retrieved_definitions: u8,
}
-#[derive(Debug, Args)]
-pub struct PredictArguments {
- #[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
- format: PredictionsOutputFormat,
- example_path: PathBuf,
- #[clap(flatten)]
- options: PredictionOptions,
+#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
+enum PromptFormat {
+ Teacher,
+ Zeta2,
}
#[derive(Debug, Args)]
-pub struct DistillArguments {
- split_commit_dataset: PathBuf,
- #[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
- context_type: ContextType,
- #[clap(long)]
- batch: Option<String>,
-}
-
-#[derive(Clone, Debug, Args)]
-pub struct PredictionOptions {
- #[clap(flatten)]
- zeta2: Zeta2Args,
+struct PredictArgs {
#[clap(long)]
provider: PredictionProvider,
- #[clap(long, value_enum, default_value_t = CacheMode::default())]
- cache: CacheMode,
-}
-
-#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
-pub enum CacheMode {
- /// Use cached LLM requests and responses, except when multiple repetitions are requested
- #[default]
- Auto,
- /// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
- #[value(alias = "request")]
- Requests,
- /// Ignore existing cache entries for both LLM and search.
- Skip,
- /// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
- /// Useful for reproducing results and fixing bugs outside of search queries
- Force,
-}
-
-impl CacheMode {
- fn use_cached_llm_responses(&self) -> bool {
- self.assert_not_auto();
- matches!(self, CacheMode::Requests | CacheMode::Force)
- }
-
- fn use_cached_search_results(&self) -> bool {
- self.assert_not_auto();
- matches!(self, CacheMode::Force)
- }
-
- fn assert_not_auto(&self) {
- assert_ne!(
- *self,
- CacheMode::Auto,
- "Cache mode should not be auto at this point!"
- );
- }
-}
-
-#[derive(clap::ValueEnum, Debug, Clone)]
-pub enum PredictionsOutputFormat {
- Json,
- Md,
- Diff,
+ #[clap(long, default_value_t = 1)]
+ repetitions: usize,
}
-#[derive(Debug, Args)]
-pub struct EvaluateArguments {
- example_paths: Vec<PathBuf>,
- #[clap(flatten)]
- options: PredictionOptions,
- #[clap(short, long, default_value_t = 1, alias = "repeat")]
- repetitions: u16,
- #[arg(long)]
- skip_prediction: bool,
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
+#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
+ Sweep,
+ Mercury,
Zeta1,
- #[default]
Zeta2,
- Sweep,
-}
-
-fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
- edit_prediction::ZetaOptions {
- context: EditPredictionExcerptOptions {
- max_bytes: args.max_excerpt_bytes,
- min_bytes: args.min_excerpt_bytes,
- target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
- },
- max_prompt_bytes: args.max_prompt_bytes,
- prompt_format: args.prompt_format.into(),
- }
-}
-
-#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
-enum PromptFormat {
- OnlySnippets,
- #[default]
- OldTextNewText,
- Minimal,
- MinimalQwen,
- SeedCoder1120,
+ Teacher,
}
-impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
- fn into(self) -> predict_edits_v3::PromptFormat {
- match self {
- Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
- Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
- Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
- Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
- Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
+impl EpArgs {
+ fn output_path(&self) -> Option<PathBuf> {
+ if self.in_place {
+ if self.inputs.len() == 1 {
+ self.inputs.first().cloned()
+ } else {
+ panic!("--in-place requires exactly one input file")
+ }
+ } else {
+ self.output.clone()
}
}
}
-#[derive(clap::ValueEnum, Default, Debug, Clone)]
-enum OutputFormat {
- #[default]
- Prompt,
- Request,
- Full,
-}
-
-#[derive(Debug, Clone)]
-enum FileOrStdin {
- File(PathBuf),
- Stdin,
-}
+fn main() {
+ zlog::init();
+ zlog::init_output_stderr();
+ let args = EpArgs::parse();
-impl FileOrStdin {
- async fn read_to_string(&self) -> Result<String, std::io::Error> {
- match self {
- FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
- FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
- }
+ if args.printenv {
+ ::util::shell_env::print_env();
+ return;
}
-}
-
-impl FromStr for FileOrStdin {
- type Err = <PathBuf as FromStr>::Err;
- fn from_str(s: &str) -> Result<Self, Self::Err> {
- match s {
- "-" => Ok(Self::Stdin),
- _ => Ok(Self::File(PathBuf::from_str(s)?)),
+ let output = args.output_path();
+ let command = match args.command {
+ Some(cmd) => cmd,
+ None => {
+ EpArgs::command().print_help().unwrap();
+ return;
}
- }
-}
-
-struct LoadedContext {
- full_path_str: String,
- snapshot: BufferSnapshot,
- clipped_cursor: Point,
- worktree: Entity<Worktree>,
- project: Entity<Project>,
- buffer: Entity<Buffer>,
- lsp_open_handle: Option<OpenLspBufferHandle>,
-}
-
-async fn load_context(
- args: &ContextArgs,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<LoadedContext> {
- let ContextArgs {
- worktree: worktree_path,
- cursor,
- use_language_server,
- ..
- } = args;
-
- let worktree_path = worktree_path.canonicalize()?;
-
- let project = cx.update(|cx| {
- Project::local(
- app_state.client.clone(),
- app_state.node_runtime.clone(),
- app_state.user_store.clone(),
- app_state.languages.clone(),
- app_state.fs.clone(),
- None,
- cx,
- )
- })?;
-
- let worktree = project
- .update(cx, |project, cx| {
- project.create_worktree(&worktree_path, true, cx)
- })?
- .await?;
-
- let mut ready_languages = HashSet::default();
- let (lsp_open_handle, buffer) = if *use_language_server {
- let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
- project.clone(),
- worktree.clone(),
- cursor.path.clone(),
- &mut ready_languages,
- cx,
- )
- .await?;
- (Some(lsp_open_handle), buffer)
- } else {
- let buffer =
- open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
- (None, buffer)
};
- let full_path_str = worktree
- .read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
- .display(PathStyle::local())
- .to_string();
-
- let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
- let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
- if clipped_cursor != cursor.point {
- let max_row = snapshot.max_point().row;
- if cursor.point.row < max_row {
- return Err(anyhow!(
- "Cursor position {:?} is out of bounds (line length is {})",
- cursor.point,
- snapshot.line_len(cursor.point.row)
- ));
- } else {
- return Err(anyhow!(
- "Cursor position {:?} is out of bounds (max row is {})",
- cursor.point,
- max_row
- ));
+ match &command {
+ Command::Clean => {
+ std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
+ return;
}
+ _ => {}
}
- Ok(LoadedContext {
- full_path_str,
- snapshot,
- clipped_cursor,
- worktree,
- project,
- buffer,
- lsp_open_handle,
- })
-}
-
-async fn zeta2_context(
- args: ContextArgs,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<String> {
- let LoadedContext {
- worktree,
- project,
- buffer,
- clipped_cursor,
- lsp_open_handle: _handle,
- ..
- } = load_context(&args, app_state, cx).await?;
-
- // wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
- // the whole worktree.
- worktree
- .read_with(cx, |worktree, _cx| {
- worktree.as_local().unwrap().scan_complete()
- })?
- .await;
- let output = cx
- .update(|cx| {
- let store = cx.new(|cx| {
- edit_prediction::EditPredictionStore::new(
- app_state.client.clone(),
- app_state.user_store.clone(),
- cx,
- )
- });
- store.update(cx, |store, cx| {
- store.set_options(zeta2_args_to_options(&args.zeta2_args));
- store.register_buffer(&buffer, &project, cx);
- });
- cx.spawn(async move |cx| {
- let updates_rx = store.update(cx, |store, cx| {
- let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
- store.set_use_context(true);
- store.refresh_context(&project, &buffer, cursor, cx);
- store.project_context_updates(&project).unwrap()
- })?;
-
- updates_rx.recv().await.ok();
-
- let context = store.update(cx, |store, cx| {
- store.context_for_project(&project, cx).to_vec()
- })?;
-
- anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
- })
- })?
- .await?;
-
- Ok(output)
-}
-
-async fn zeta1_context(
- args: ContextArgs,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<edit_prediction::zeta1::GatherContextOutput> {
- let LoadedContext {
- full_path_str,
- snapshot,
- clipped_cursor,
- ..
- } = load_context(&args, app_state, cx).await?;
-
- let events = match args.edit_history {
- Some(events) => events.read_to_string().await?,
- None => String::new(),
- };
-
- let prompt_for_events = move || (events, 0);
- cx.update(|cx| {
- edit_prediction::zeta1::gather_context(
- full_path_str,
- &snapshot,
- clipped_cursor,
- prompt_for_events,
- cloud_llm_client::PredictEditsRequestTrigger::Cli,
- cx,
- )
- })?
- .await
-}
-
-fn main() {
- zlog::init();
- zlog::init_output_stderr();
- let args = ZetaCliArgs::parse();
+ let mut examples = read_examples(&args.inputs);
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
+ EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
+
cx.spawn(async move |cx| {
- match args.command {
- None => {
- if args.printenv {
- ::util::shell_env::print_env();
- } else {
- panic!("Expected a command");
- }
- }
- Some(Command::Context(context_args)) => {
- let result = match context_args.provider {
- ContextProvider::Zeta1 => {
- let context =
- zeta1_context(context_args, &app_state, cx).await.unwrap();
- serde_json::to_string_pretty(&context.body).unwrap()
- }
- ContextProvider::Zeta2 => {
- zeta2_context(context_args, &app_state, cx).await.unwrap()
+ match &command {
+ Command::Predict(args) => predict::sync_batches(&args.provider).await,
+ _ => (),
+ };
+
+ for data in examples.chunks_mut(args.max_parallelism) {
+ let mut futures = Vec::new();
+ for example in data.iter_mut() {
+ let cx = cx.clone();
+ let app_state = app_state.clone();
+ futures.push(async {
+ match &command {
+ Command::ParseExample => {}
+ Command::LoadBuffer => {
+ run_load_project(example, app_state.clone(), cx).await;
+ }
+ Command::Context => {
+ run_context_retrieval(example, app_state, cx).await;
+ }
+ Command::FormatPrompt(args) => {
+ run_format_prompt(example, args.prompt_format, app_state, cx).await;
+ }
+ Command::Predict(args) => {
+ run_prediction(
+ example,
+ Some(args.provider),
+ args.repetitions,
+ app_state.clone(),
+ cx,
+ )
+ .await;
+ }
+ Command::Score(args) | Command::Eval(args) => {
+ run_scoring(example, &args, app_state, cx).await;
+ }
+ Command::Clean => {
+ unreachable!()
+ }
}
- };
- println!("{}", result);
- }
- Some(Command::Predict(arguments)) => {
- run_predict(arguments, &app_state, cx).await;
- }
- Some(Command::Eval(arguments)) => {
- run_evaluate(arguments, &app_state, cx).await;
+ });
}
- Some(Command::Distill(arguments)) => {
- let _guard = cx
- .update(|cx| gpui_tokio::Tokio::handle(cx))
- .unwrap()
- .enter();
- run_distill(arguments).await.log_err();
- }
- Some(Command::ConvertExample {
- path,
- output_format,
- }) => {
- let example = NamedExample::load(path).unwrap();
- example.write(output_format, io::stdout()).unwrap();
- }
- Some(Command::Score {
- golden_patch,
- actual_patch,
- }) => {
- let golden_content = std::fs::read_to_string(golden_patch).unwrap();
- let actual_content = std::fs::read_to_string(actual_patch).unwrap();
-
- let golden_diff: Vec<DiffLine> = golden_content
- .lines()
- .map(|line| DiffLine::parse(line))
- .collect();
+ futures::future::join_all(futures).await;
+ }
- let actual_diff: Vec<DiffLine> = actual_content
- .lines()
- .map(|line| DiffLine::parse(line))
- .collect();
+ if args.output.is_some() || !matches!(command, Command::Eval(_)) {
+ write_examples(&examples, output.as_ref());
+ }
- let score = delta_chr_f(&golden_diff, &actual_diff);
- println!("{:.2}", score);
- }
- Some(Command::Clean) => {
- std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
- }
+ match &command {
+ Command::Predict(args) => predict::sync_batches(&args.provider).await,
+ Command::Eval(_) => score::print_report(&examples),
+ _ => (),
};
let _ = cx.update(|cx| cx.quit());
@@ -1,30 +1,34 @@
use collections::{HashMap, HashSet};
use edit_prediction::udiff::DiffLine;
+use serde::{Deserialize, Serialize};
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
-#[derive(Default, Debug, Clone)]
-pub struct Scores {
+#[derive(Default, Debug, Clone, Serialize, Deserialize)]
+pub struct ClassificationMetrics {
pub true_positives: usize,
pub false_positives: usize,
pub false_negatives: usize,
}
-impl Scores {
- pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
+impl ClassificationMetrics {
+ pub fn from_sets(
+ expected: &HashSet<String>,
+ actual: &HashSet<String>,
+ ) -> ClassificationMetrics {
let true_positives = expected.intersection(actual).count();
let false_positives = actual.difference(expected).count();
let false_negatives = expected.difference(actual).count();
- Scores {
+ ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
- pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
+ pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -45,32 +49,16 @@ impl Scores {
}
}
- Scores {
+ ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
- pub fn to_markdown(&self) -> String {
- format!(
- "
-Precision : {:.4}
-Recall : {:.4}
-F1 Score : {:.4}
-True Positives : {}
-False Positives : {}
-False Negatives : {}",
- self.precision(),
- self.recall(),
- self.f1_score(),
- self.true_positives,
- self.false_positives,
- self.false_negatives
- )
- }
-
- pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
+ pub fn aggregate<'a>(
+ scores: impl Iterator<Item = &'a ClassificationMetrics>,
+ ) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -81,7 +69,7 @@ False Negatives : {}",
false_negatives += score.false_negatives;
}
- Scores {
+ ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
@@ -115,7 +103,10 @@ False Negatives : {}",
}
}
-pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
+pub fn line_match_score(
+ expected_patch: &[DiffLine],
+ actual_patch: &[DiffLine],
+) -> ClassificationMetrics {
let expected_change_lines = expected_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
.map(|line| line.to_string())
.collect();
- Scores::from_sets(&expected_change_lines, &actual_change_lines)
+ ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
let expected_counts = ngram_delta_to_counts(&expected_delta);
let actual_counts = ngram_delta_to_counts(&actual_delta);
- let score = Scores::from_counts(&expected_counts, &actual_counts);
+ let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
total_precision += score.precision();
total_recall += score.recall();
}
@@ -1,57 +1,25 @@
-use std::{env, path::PathBuf, sync::LazyLock};
+use std::{
+ path::{Path, PathBuf},
+ sync::LazyLock,
+};
-pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
- LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
-pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
-pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
-pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
+pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
+ let dir = dirs::home_dir().unwrap().join(".zed_ep");
+ ensure_dir(&dir)
+});
+pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
+pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
+pub static WORKTREES_DIR: LazyLock<PathBuf> =
+ LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
- TARGET_ZETA_DIR
+ DATA_DIR
.join("runs")
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
});
-pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
- LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
-
-pub fn print_run_data_dir(deep: bool, use_color: bool) {
- println!("\n## Run Data\n");
- let mut files = Vec::new();
-
- let current_dir = std::env::current_dir().unwrap();
- for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
- let file = file.unwrap();
- if file.file_type().unwrap().is_dir() && deep {
- for file in std::fs::read_dir(file.path()).unwrap() {
- let path = file.unwrap().path();
- let path = path.strip_prefix(¤t_dir).unwrap_or(&path);
- files.push(format!(
- "- {}/{}{}{}",
- path.parent().unwrap().display(),
- if use_color { "\x1b[34m" } else { "" },
- path.file_name().unwrap().display(),
- if use_color { "\x1b[0m" } else { "" },
- ));
- }
- } else {
- let path = file.path();
- let path = path.strip_prefix(¤t_dir).unwrap_or(&path);
- files.push(format!(
- "- {}/{}{}{}",
- path.parent().unwrap().display(),
- if use_color { "\x1b[34m" } else { "" },
- path.file_name().unwrap().display(),
- if use_color { "\x1b[0m" } else { "" }
- ));
- }
- }
- files.sort();
-
- for file in files {
- println!("{}", file);
- }
+pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
+pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
- println!(
- "\n💡 Tip of the day: {} always points to the latest run\n",
- LATEST_EXAMPLE_RUN_DIR.display()
- );
+fn ensure_dir(path: &Path) -> PathBuf {
+ std::fs::create_dir_all(path).expect("Failed to create directory");
+ path.to_path_buf()
}
@@ -1,374 +1,271 @@
-use crate::example::{ActualExcerpt, NamedExample};
-use crate::headless::ZetaCliAppState;
-use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
- CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
+ PredictionProvider, PromptFormat,
+ anthropic_client::AnthropicClient,
+ example::{Example, ExamplePrediction},
+ format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
+ headless::EpAppState,
+ load_project::run_load_project,
+ paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
+ retrieve_context::run_context_retrieval,
+};
+use edit_prediction::{DebugEvent, EditPredictionStore};
+use futures::{FutureExt as _, StreamExt as _, future::Shared};
+use gpui::{AppContext as _, AsyncApp, Task};
+use std::{
+ fs,
+ sync::{
+ Arc, Mutex, OnceLock,
+ atomic::{AtomicUsize, Ordering::SeqCst},
+ },
};
-use ::serde::Serialize;
-use anyhow::{Context, Result, anyhow};
-use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
-use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
-use futures::StreamExt as _;
-use gpui::{AppContext, AsyncApp, Entity};
-use project::Project;
-use project::buffer_store::BufferStoreEvent;
-use serde::Deserialize;
-use std::fs;
-use std::io::{IsTerminal, Write};
-use std::path::PathBuf;
-use std::sync::Arc;
-use std::sync::Mutex;
-use std::time::{Duration, Instant};
-pub async fn run_predict(
- args: PredictArguments,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
+pub async fn run_prediction(
+ example: &mut Example,
+ provider: Option<PredictionProvider>,
+ repetition_count: usize,
+ app_state: Arc<EpAppState>,
+ mut cx: AsyncApp,
) {
- let example = NamedExample::load(args.example_path).unwrap();
- let project = example.setup_project(app_state, cx).await.unwrap();
- let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
- let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
- let result = perform_predict(example, project, store, None, args.options, cx)
- .await
- .unwrap();
- result.write(args.format, std::io::stdout()).unwrap();
-
- print_run_data_dir(true, std::io::stdout().is_terminal());
-}
-
-pub fn setup_store(
- provider: PredictionProvider,
- project: &Entity<Project>,
- app_state: &Arc<ZetaCliAppState>,
- cx: &mut AsyncApp,
-) -> Result<Entity<EditPredictionStore>> {
- let store = cx.new(|cx| {
- edit_prediction::EditPredictionStore::new(
- app_state.client.clone(),
- app_state.user_store.clone(),
- cx,
- )
- })?;
+ if !example.predictions.is_empty() {
+ return;
+ }
- store.update(cx, |store, _cx| {
- let model = match provider {
- PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
- PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
- PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
- };
- store.set_edit_prediction_model(model);
- })?;
+ run_load_project(example, app_state.clone(), cx.clone()).await;
+ run_context_retrieval(example, app_state.clone(), cx.clone()).await;
- let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
+ let provider = provider.unwrap();
- cx.subscribe(&buffer_store, {
- let project = project.clone();
- let store = store.clone();
- move |_, event, cx| match event {
- BufferStoreEvent::BufferAdded(buffer) => {
- store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
- }
- _ => {}
+ if matches!(provider, PredictionProvider::Teacher) {
+ if example.prompt.is_none() {
+ run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
}
- })?
- .detach();
- anyhow::Ok(store)
-}
-
-pub async fn perform_predict(
- example: NamedExample,
- project: Entity<Project>,
- store: Entity<EditPredictionStore>,
- repetition_ix: Option<u16>,
- options: PredictionOptions,
- cx: &mut AsyncApp,
-) -> Result<PredictionDetails> {
- let mut cache_mode = options.cache;
- if repetition_ix.is_some() {
- if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
- panic!("Repetitions are not supported in Auto cache mode");
- } else {
- cache_mode = CacheMode::Skip;
- }
- } else if cache_mode == CacheMode::Auto {
- cache_mode = CacheMode::Requests;
+ let batched = true;
+ return predict_anthropic(example, repetition_count, batched).await;
}
- let mut example_run_dir = RUN_DIR.join(&example.file_name());
- if let Some(repetition_ix) = repetition_ix {
- example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
- }
- fs::create_dir_all(&example_run_dir)?;
- if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
- fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
+ if matches!(
+ provider,
+ PredictionProvider::Zeta1 | PredictionProvider::Zeta2
+ ) {
+ static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
+ AUTHENTICATED
+ .get_or_init(|| {
+ let client = app_state.client.clone();
+ cx.spawn(async move |cx| {
+ client
+ .sign_in_with_optional_connect(true, cx)
+ .await
+ .unwrap();
+ })
+ .shared()
+ })
+ .clone()
+ .await;
}
- #[cfg(unix)]
- std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
- .context("creating latest link")?;
-
- #[cfg(windows)]
- std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
- .context("creating latest link")?;
-
- store.update(cx, |store, _cx| {
- store.with_eval_cache(Arc::new(RunCache {
- example_run_dir: example_run_dir.clone(),
- cache_mode,
- }));
- })?;
-
- let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
-
- let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
-
- let prompt_format = options.zeta2.prompt_format;
-
- store.update(cx, |store, _cx| {
- let mut options = store.options().clone();
- options.prompt_format = prompt_format.into();
- store.set_options(options);
- })?;
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
- let mut debug_task = gpui::Task::ready(Ok(()));
+ ep_store
+ .update(&mut cx, |store, _cx| {
+ let model = match provider {
+ PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
+ PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+ PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
+ PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
+ PredictionProvider::Teacher => unreachable!(),
+ };
+ store.set_edit_prediction_model(model);
+ })
+ .unwrap();
+ let state = example.state.as_ref().unwrap();
+ let run_dir = RUN_DIR.join(&example.name);
- if options.provider == crate::PredictionProvider::Zeta2 {
- let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
+ let updated_example = Arc::new(Mutex::new(example.clone()));
+ let current_run_ix = Arc::new(AtomicUsize::new(0));
- debug_task = cx.background_spawn({
- let result = result.clone();
- async move {
- let mut start_time = None;
- let mut retrieval_finished_at = None;
- while let Some(event) = debug_rx.next().await {
- match event {
- edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
- start_time = Some(info.timestamp);
- fs::write(
- example_run_dir.join("search_prompt.md"),
- &info.search_prompt,
- )?;
+ let mut debug_rx = ep_store
+ .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
+ .unwrap();
+ let debug_task = cx.background_spawn({
+ let updated_example = updated_example.clone();
+ let current_run_ix = current_run_ix.clone();
+ let run_dir = run_dir.clone();
+ async move {
+ while let Some(event) = debug_rx.next().await {
+ let run_ix = current_run_ix.load(SeqCst);
+ let mut updated_example = updated_example.lock().unwrap();
+
+ let run_dir = if repetition_count > 1 {
+ run_dir.join(format!("{:03}", run_ix))
+ } else {
+ run_dir.clone()
+ };
+
+ match event {
+ DebugEvent::EditPredictionStarted(request) => {
+ assert_eq!(updated_example.predictions.len(), run_ix + 1);
+
+ if let Some(prompt) = request.prompt {
+ fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
}
- edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
- retrieval_finished_at = Some(info.timestamp);
- for (key, value) in &info.metadata {
- if *key == "search_queries" {
- fs::write(
- example_run_dir.join("search_queries.json"),
- value.as_bytes(),
- )?;
- }
- }
+ }
+ DebugEvent::EditPredictionFinished(request) => {
+ assert_eq!(updated_example.predictions.len(), run_ix + 1);
+
+ if let Some(output) = request.model_output {
+ fs::write(run_dir.join("prediction_response.md"), &output)?;
+ updated_example
+ .predictions
+ .last_mut()
+ .unwrap()
+ .actual_output = output;
}
- edit_prediction::DebugEvent::EditPredictionRequested(request) => {
- let prediction_started_at = Instant::now();
- start_time.get_or_insert(prediction_started_at);
- let prompt = request.local_prompt.unwrap_or_default();
- fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
-
- {
- let mut result = result.lock().unwrap();
- result.prompt_len = prompt.chars().count();
-
- for included_file in request.inputs.included_files {
- let insertions =
- vec![(request.inputs.cursor_point, CURSOR_MARKER)];
- result.excerpts.extend(included_file.excerpts.iter().map(
- |excerpt| ActualExcerpt {
- path: included_file.path.components().skip(1).collect(),
- text: String::from(excerpt.text.as_ref()),
- },
- ));
- write_codeblock(
- &included_file.path,
- included_file.excerpts.iter(),
- if included_file.path == request.inputs.cursor_path {
- &insertions
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut result.excerpts_text,
- );
- }
- }
-
- let response =
- request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- let response =
- edit_prediction::open_ai_response::text_from_response(response)
- .unwrap_or_default();
- let prediction_finished_at = Instant::now();
- fs::write(example_run_dir.join("prediction_response.md"), &response)?;
-
- let mut result = result.lock().unwrap();
- result.generated_len = response.chars().count();
- result.retrieval_time =
- retrieval_finished_at.unwrap() - start_time.unwrap();
- result.prediction_time = prediction_finished_at - prediction_started_at;
- result.total_time = prediction_finished_at - start_time.unwrap();
-
+ if run_ix >= repetition_count {
break;
}
}
+ _ => {}
}
- anyhow::Ok(())
}
- });
-
- store.update(cx, |store, cx| {
- store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
- })?;
- }
-
- let prediction = store
- .update(cx, |store, cx| {
- store.request_prediction(
- &project,
- &cursor_buffer,
- cursor_anchor,
- cloud_llm_client::PredictEditsRequestTrigger::Cli,
- cx,
- )
- })?
- .await?;
-
- debug_task.await?;
-
- let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
-
- result.diff = prediction
- .and_then(|prediction| {
- let prediction = prediction.prediction.ok()?;
- prediction.edit_preview.as_unified_diff(&prediction.edits)
- })
- .unwrap_or_default();
-
- anyhow::Ok(result)
-}
-
-struct RunCache {
- cache_mode: CacheMode,
- example_run_dir: PathBuf,
-}
+ anyhow::Ok(())
+ }
+ });
-impl RunCache {
- fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
- CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
- }
+ for ix in 0..repetition_count {
+ current_run_ix.store(ix, SeqCst);
+ let run_dir = if repetition_count > 1 {
+ run_dir.join(format!("{:03}", ix))
+ } else {
+ run_dir.clone()
+ };
- fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
- CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
+ fs::create_dir_all(&run_dir).unwrap();
+ if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
+ fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
+ }
+ #[cfg(unix)]
+ std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+ #[cfg(windows)]
+ std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
+
+ updated_example
+ .lock()
+ .unwrap()
+ .predictions
+ .push(ExamplePrediction {
+ actual_patch: String::new(),
+ actual_output: String::new(),
+ provider,
+ });
+
+ let prediction = ep_store
+ .update(&mut cx, |store, cx| {
+ store.request_prediction(
+ &state.project,
+ &state.buffer,
+ state.cursor_position,
+ cloud_llm_client::PredictEditsRequestTrigger::Cli,
+ cx,
+ )
+ })
+ .unwrap()
+ .await
+ .unwrap();
+
+ updated_example
+ .lock()
+ .unwrap()
+ .predictions
+ .last_mut()
+ .unwrap()
+ .actual_patch = prediction
+ .and_then(|prediction| {
+ let prediction = prediction.prediction.ok()?;
+ prediction.edit_preview.as_unified_diff(&prediction.edits)
+ })
+ .unwrap_or_default();
}
- fn link_to_run(&self, key: &EvalCacheKey) {
- let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
- fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
+ ep_store
+ .update(&mut cx, |store, _| {
+ store.remove_project(&state.project);
+ })
+ .unwrap();
+ debug_task.await.unwrap();
- let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
- fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
- }
+ *example = Arc::into_inner(updated_example)
+ .unwrap()
+ .into_inner()
+ .unwrap();
}
-impl EvalCache for RunCache {
- fn read(&self, key: EvalCacheKey) -> Option<String> {
- let path = RunCache::output_cache_path(&key);
-
- if path.exists() {
- let use_cache = match key.0 {
- EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
- EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
- self.cache_mode.use_cached_llm_responses()
- }
- };
- if use_cache {
- log::info!("Using cache entry: {}", path.display());
- self.link_to_run(&key);
- Some(fs::read_to_string(path).unwrap())
- } else {
- log::trace!("Skipping cached entry: {}", path.display());
- None
- }
- } else if matches!(self.cache_mode, CacheMode::Force) {
- panic!(
- "No cached entry found for {:?}. Run without `--cache force` at least once.",
- key.0
- );
- } else {
- None
- }
- }
-
- fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
- fs::create_dir_all(&*CACHE_DIR).unwrap();
+async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
+ let llm_model_name = "claude-sonnet-4-5";
+ let max_tokens = 16384;
+ let llm_client = if batched {
+ AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
+ } else {
+ AnthropicClient::plain()
+ };
+ let llm_client = llm_client.expect("Failed to create LLM client");
+
+ let prompt = example
+ .prompt
+ .as_ref()
+ .unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
+
+ let messages = vec![anthropic::Message {
+ role: anthropic::Role::User,
+ content: vec![anthropic::RequestContent::Text {
+ text: prompt.input.clone(),
+ cache_control: None,
+ }],
+ }];
+
+ let Some(response) = llm_client
+ .generate(llm_model_name, max_tokens, messages)
+ .await
+ .unwrap()
+ else {
+ // Request stashed for batched processing
+ return;
+ };
+
+ let actual_output = response
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ anthropic::ResponseContent::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n");
- let input_path = RunCache::input_cache_path(&key);
- fs::write(&input_path, input).unwrap();
+ let actual_patch = TeacherPrompt::parse(example, &actual_output);
- let output_path = RunCache::output_cache_path(&key);
- log::trace!("Writing cache entry: {}", output_path.display());
- fs::write(&output_path, output).unwrap();
+ let prediction = ExamplePrediction {
+ actual_patch,
+ actual_output,
+ provider: PredictionProvider::Teacher,
+ };
- self.link_to_run(&key);
- }
+ example.predictions.push(prediction);
}
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct PredictionDetails {
- pub diff: String,
- pub excerpts: Vec<ActualExcerpt>,
- pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
- pub retrieval_time: Duration,
- pub prediction_time: Duration,
- pub total_time: Duration,
- pub run_example_dir: PathBuf,
- pub prompt_len: usize,
- pub generated_len: usize,
-}
-
-impl PredictionDetails {
- pub fn new(run_example_dir: PathBuf) -> Self {
- Self {
- diff: Default::default(),
- excerpts: Default::default(),
- excerpts_text: Default::default(),
- retrieval_time: Default::default(),
- prediction_time: Default::default(),
- total_time: Default::default(),
- run_example_dir,
- prompt_len: 0,
- generated_len: 0,
+pub async fn sync_batches(provider: &PredictionProvider) {
+ match provider {
+ PredictionProvider::Teacher => {
+ let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
+ let llm_client =
+ AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
+ llm_client
+ .sync_batches()
+ .await
+ .expect("Failed to sync batches");
}
- }
-
- pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
- let formatted = match format {
- PredictionsOutputFormat::Md => self.to_markdown(),
- PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
- PredictionsOutputFormat::Diff => self.diff.clone(),
- };
-
- Ok(out.write_all(formatted.as_bytes())?)
- }
-
- pub fn to_markdown(&self) -> String {
- format!(
- "## Excerpts\n\n\
- {}\n\n\
- ## Prediction\n\n\
- {}\n\n\
- ## Time\n\n\
- Retrieval: {}ms\n\
- Prediction: {}ms\n\n\
- Total: {}ms\n",
- self.excerpts_text,
- self.diff,
- self.retrieval_time.as_millis(),
- self.prediction_time.as_millis(),
- self.total_time.as_millis(),
- )
+ _ => (),
}
}
@@ -1,106 +1,136 @@
-use anyhow::{Result, anyhow};
-use futures::channel::mpsc;
-use futures::{FutureExt as _, StreamExt as _};
+use crate::{
+ example::{Example, ExampleContext},
+ headless::EpAppState,
+ load_project::run_load_project,
+};
+use anyhow::Result;
+use collections::HashSet;
+use edit_prediction::{DebugEvent, EditPredictionStore};
+use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
use gpui::{AsyncApp, Entity, Task};
-use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
-use project::lsp_store::OpenLspBufferHandle;
-use project::{Project, ProjectPath, Worktree};
-use std::collections::HashSet;
-use std::sync::Arc;
-use std::time::Duration;
-use util::rel_path::RelPath;
-
-pub fn open_buffer(
- project: Entity<Project>,
- worktree: Entity<Worktree>,
- path: Arc<RelPath>,
- cx: &AsyncApp,
-) -> Task<Result<Entity<Buffer>>> {
- cx.spawn(async move |cx| {
- let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
- worktree_id: worktree.id(),
- path,
- })?;
-
- let buffer = project
- .update(cx, |project, cx| project.open_buffer(project_path, cx))?
- .await?;
-
- let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
- while *parse_status.borrow() != ParseStatus::Idle {
- parse_status.changed().await?;
+use language::{Buffer, LanguageNotFound};
+use project::Project;
+use std::{sync::Arc, time::Duration};
+
+pub async fn run_context_retrieval(
+ example: &mut Example,
+ app_state: Arc<EpAppState>,
+ mut cx: AsyncApp,
+) {
+ if example.context.is_some() {
+ return;
+ }
+
+ run_load_project(example, app_state.clone(), cx.clone()).await;
+
+ let state = example.state.as_ref().unwrap();
+ let project = state.project.clone();
+
+ let _lsp_handle = project
+ .update(&mut cx, |project, cx| {
+ project.register_buffer_with_language_servers(&state.buffer, cx)
+ })
+ .unwrap();
+
+ wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
+
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx).unwrap())
+ .unwrap();
+
+ let mut events = ep_store
+ .update(&mut cx, |store, cx| {
+ store.register_buffer(&state.buffer, &project, cx);
+ store.set_use_context(true);
+ store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
+ store.debug_info(&project, cx)
+ })
+ .unwrap();
+
+ while let Some(event) = events.next().await {
+ match event {
+ DebugEvent::ContextRetrievalFinished(_) => {
+ break;
+ }
+ _ => {}
}
+ }
- Ok(buffer)
- })
+ let context_files = ep_store
+ .update(&mut cx, |store, cx| store.context_for_project(&project, cx))
+ .unwrap();
+
+ example.context = Some(ExampleContext {
+ files: context_files,
+ });
}
-pub async fn open_buffer_with_language_server(
- project: Entity<Project>,
- worktree: Entity<Worktree>,
- path: Arc<RelPath>,
- ready_languages: &mut HashSet<LanguageId>,
+async fn wait_for_language_server_to_start(
+ example: &Example,
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
cx: &mut AsyncApp,
-) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
- let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
-
- let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
- (
- project.register_buffer_with_language_servers(&buffer, cx),
- project.path_style(cx),
- )
- })?;
-
- let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
+) {
+ let language_registry = project
+ .read_with(cx, |project, _| project.languages().clone())
+ .unwrap();
let result = language_registry
- .load_language_for_file_path(path.as_std_path())
+ .load_language_for_file_path(&example.cursor_path)
.await;
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
- anyhow::bail!(error);
+ panic!("Failed to load language for file path: {}", error);
}
- let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
- buffer.language().map(|language| language.id())
- })?
+ let Some(language_id) = buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer.language().map(|language| language.id())
+ })
+ .unwrap()
else {
- return Err(anyhow!("No language for {}", path.display(path_style)));
+ panic!("No language for {:?}", example.cursor_path);
};
- let log_prefix = format!("{} | ", path.display(path_style));
+ let mut ready_languages = HashSet::default();
+ let log_prefix = format!("{} | ", example.name);
if !ready_languages.contains(&language_id) {
- wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
+ wait_for_lang_server(&project, &buffer, log_prefix, cx)
+ .await
+ .unwrap();
ready_languages.insert(language_id);
}
- let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
+ let lsp_store = project
+ .read_with(cx, |project, _cx| project.lsp_store())
+ .unwrap();
// hacky wait for buffer to be registered with the language server
for _ in 0..100 {
- let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
- buffer.update(cx, |buffer, cx| {
- lsp_store
- .language_servers_for_local_buffer(&buffer, cx)
- .next()
- .map(|(_, language_server)| language_server.server_id())
+ if lsp_store
+ .update(cx, |lsp_store, cx| {
+ buffer.update(cx, |buffer, cx| {
+ lsp_store
+ .language_servers_for_local_buffer(&buffer, cx)
+ .next()
+ .map(|(_, language_server)| language_server.server_id())
+ })
})
- })?
- else {
+ .unwrap()
+ .is_some()
+ {
+ return;
+ } else {
cx.background_executor()
.timer(Duration::from_millis(10))
.await;
- continue;
- };
-
- return Ok((lsp_open_handle, language_server_id, buffer));
+ }
}
- return Err(anyhow!("No language server found for buffer"));
+ panic!("No language server found for buffer");
}
-// TODO: Dedupe with similar function in crates/eval/src/instance.rs
pub fn wait_for_lang_server(
project: &Entity<Project>,
buffer: &Entity<Buffer>,
@@ -0,0 +1,119 @@
+use crate::{
+ PredictArgs,
+ example::{Example, ExampleScore},
+ headless::EpAppState,
+ metrics::{self, ClassificationMetrics},
+ predict::run_prediction,
+};
+use edit_prediction::udiff::DiffLine;
+use gpui::AsyncApp;
+use std::sync::Arc;
+
+pub async fn run_scoring(
+ example: &mut Example,
+ args: &PredictArgs,
+ app_state: Arc<EpAppState>,
+ cx: AsyncApp,
+) {
+ run_prediction(
+ example,
+ Some(args.provider),
+ args.repetitions,
+ app_state,
+ cx,
+ )
+ .await;
+
+ let expected_patch = parse_patch(&example.expected_patch);
+
+ let mut scores = vec![];
+
+ for pred in &example.predictions {
+ let actual_patch = parse_patch(&pred.actual_patch);
+ let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
+ let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
+
+ scores.push(ExampleScore {
+ delta_chr_f,
+ line_match,
+ });
+ }
+
+ example.score = scores;
+}
+
+fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
+ patch.lines().map(DiffLine::parse).collect()
+}
+
+pub fn print_report(examples: &[Example]) {
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+ eprintln!(
+ "{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
+ "Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
+ );
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+
+ let mut all_line_match_scores = Vec::new();
+ let mut all_delta_chr_f_scores = Vec::new();
+
+ for example in examples {
+ for score in example.score.iter() {
+ let line_match = &score.line_match;
+
+ eprintln!(
+ "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
+ truncate_name(&example.name, 30),
+ line_match.true_positives,
+ line_match.false_positives,
+ line_match.false_negatives,
+ line_match.precision() * 100.0,
+ line_match.recall() * 100.0,
+ line_match.f1_score() * 100.0,
+ score.delta_chr_f
+ );
+
+ all_line_match_scores.push(line_match.clone());
+ all_delta_chr_f_scores.push(score.delta_chr_f);
+ }
+ }
+
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+
+ if !all_line_match_scores.is_empty() {
+ let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
+ let avg_delta_chr_f: f32 =
+ all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
+
+ eprintln!(
+ "{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
+ "TOTAL",
+ total_line_match.true_positives,
+ total_line_match.false_positives,
+ total_line_match.false_negatives,
+ total_line_match.precision() * 100.0,
+ total_line_match.recall() * 100.0,
+ total_line_match.f1_score() * 100.0,
+ avg_delta_chr_f
+ );
+ eprintln!(
+ "──────────────────────────────────────────────────────────────────────────────────────"
+ );
+ }
+
+ eprintln!("\n");
+}
+
+fn truncate_name(name: &str, max_len: usize) -> String {
+ if name.len() <= max_len {
+ name.to_string()
+ } else {
+ format!("{}...", &name[..max_len - 3])
+ }
+}
@@ -1,70 +0,0 @@
-use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
-
-use ::util::{paths::PathStyle, rel_path::RelPath};
-use anyhow::{Result, anyhow};
-use language::Point;
-use serde::{Deserialize, Deserializer, Serialize, Serializer};
-
-#[derive(Debug, Clone, Hash, Eq, PartialEq)]
-pub struct SourceLocation {
- pub path: Arc<RelPath>,
- pub point: Point,
-}
-
-impl Serialize for SourceLocation {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
- where
- S: Serializer,
- {
- serializer.serialize_str(&self.to_string())
- }
-}
-
-impl<'de> Deserialize<'de> for SourceLocation {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: Deserializer<'de>,
- {
- let s = String::deserialize(deserializer)?;
- s.parse().map_err(serde::de::Error::custom)
- }
-}
-
-impl Display for SourceLocation {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
- f,
- "{}:{}:{}",
- self.path.display(PathStyle::Posix),
- self.point.row + 1,
- self.point.column + 1
- )
- }
-}
-
-impl FromStr for SourceLocation {
- type Err = anyhow::Error;
-
- fn from_str(s: &str) -> Result<Self> {
- let parts: Vec<&str> = s.split(':').collect();
- if parts.len() != 3 {
- return Err(anyhow!(
- "Invalid source location. Expected 'file.rs:line:column', got '{}'",
- s
- ));
- }
-
- let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
- let line: u32 = parts[1]
- .parse()
- .map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
- let column: u32 = parts[2]
- .parse()
- .map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
-
- // Convert from 1-based to 0-based indexing
- let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
-
- Ok(SourceLocation { path, point })
- }
-}
@@ -46,3 +46,7 @@ Output example:
## Code Context
{{context}}
+
+## Editable region
+
+{{editable_region}}
@@ -1,89 +0,0 @@
-use std::path::Path;
-
-use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
-
-#[derive(Debug, Clone, Default, clap::ValueEnum)]
-pub enum ContextType {
- #[default]
- CurrentFile,
-}
-
-const MAX_CONTEXT_SIZE: usize = 32768;
-
-pub fn collect_context(
- context_type: &ContextType,
- worktree_dir: &Path,
- cursor: SourceLocation,
-) -> String {
- let context = match context_type {
- ContextType::CurrentFile => {
- let file_path = worktree_dir.join(cursor.path.as_std_path());
- let context = std::fs::read_to_string(&file_path).unwrap_or_default();
-
- let context = add_special_tags(&context, worktree_dir, cursor);
- context
- }
- };
-
- let region_end_offset = context.find(TeacherModel::REGION_END);
-
- if context.len() <= MAX_CONTEXT_SIZE {
- return context;
- }
-
- if let Some(region_end_offset) = region_end_offset
- && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
- {
- let to_truncate = context.len() - MAX_CONTEXT_SIZE;
- format!(
- "[...{} bytes truncated]\n{}\n",
- to_truncate,
- &context[to_truncate..]
- )
- } else {
- format!(
- "{}\n[...{} bytes truncated]\n",
- &context[..MAX_CONTEXT_SIZE],
- context.len() - MAX_CONTEXT_SIZE
- )
- }
-}
-
-/// Add <|editable_region_start/end|> tags
-fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
- let path = worktree_dir.join(cursor.path.as_std_path());
- let file = std::fs::read_to_string(&path).unwrap_or_default();
- let lines = file.lines().collect::<Vec<_>>();
- let cursor_row = cursor.point.row as usize;
- let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
- let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
-
- let snippet = lines[start_line..end_line].join("\n");
-
- if context.contains(&snippet) {
- let mut cursor_line = lines[cursor_row].to_string();
- cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
-
- let mut snippet_with_tags_lines = vec![];
- snippet_with_tags_lines.push(TeacherModel::REGION_START);
- snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
- snippet_with_tags_lines.push(&cursor_line);
- snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
- snippet_with_tags_lines.push(TeacherModel::REGION_END);
- let snippet_with_tags = snippet_with_tags_lines.join("\n");
-
- context.replace(&snippet, &snippet_with_tags)
- } else {
- log::warn!(
- "Can't find area around the cursor in the context; proceeding without special tags"
- );
- context.to_string()
- }
-}
-
-pub fn strip_special_tags(context: &str) -> String {
- context
- .replace(TeacherModel::REGION_START, "")
- .replace(TeacherModel::REGION_END, "")
- .replace(TeacherModel::USER_CURSOR, "")
-}
@@ -1,94 +0,0 @@
-use serde::Deserialize;
-use std::sync::Arc;
-
-use crate::{
- DistillArguments,
- example::Example,
- source_location::SourceLocation,
- training::{
- context::ContextType,
- llm_client::LlmClient,
- teacher::{TeacherModel, TeacherOutput},
- },
-};
-use anyhow::Result;
-use reqwest_client::ReqwestClient;
-
-#[derive(Debug, Deserialize)]
-pub struct SplitCommit {
- repo_url: String,
- commit_sha: String,
- edit_history: String,
- expected_patch: String,
- cursor_position: String,
-}
-
-pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
- let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
- .expect("Failed to read split commit dataset")
- .lines()
- .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
- .collect();
-
- let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
-
- let llm_client = if let Some(cache_path) = arguments.batch {
- LlmClient::batch(&cache_path, http_client)?
- } else {
- LlmClient::plain(http_client)?
- };
-
- let mut teacher = TeacherModel::new(
- "claude-sonnet-4-5".to_string(),
- ContextType::CurrentFile,
- llm_client,
- );
-
- let mut num_marked_for_batching = 0;
-
- for commit in split_commits {
- if let Some(distilled) = distill_one(&mut teacher, commit).await? {
- println!("{}", serde_json::to_string(&distilled)?);
- } else {
- if num_marked_for_batching == 0 {
- log::warn!("Marked for batching");
- }
- num_marked_for_batching += 1;
- }
- }
-
- eprintln!(
- "{} requests are marked for batching",
- num_marked_for_batching
- );
- let llm_client = teacher.client;
- llm_client.sync_batches().await?;
-
- Ok(())
-}
-
-pub async fn distill_one(
- teacher: &mut TeacherModel,
- commit: SplitCommit,
-) -> Result<Option<TeacherOutput>> {
- let cursor: SourceLocation = commit
- .cursor_position
- .parse()
- .expect("Failed to parse cursor position");
-
- let path = cursor.path.to_rel_path_buf();
-
- let example = Example {
- repository_url: commit.repo_url,
- revision: commit.commit_sha,
- uncommitted_diff: commit.edit_history.clone(),
- cursor_path: path.as_std_path().to_path_buf(),
- cursor_position: commit.cursor_position,
- edit_history: commit.edit_history, // todo: trim
- expected_patch: commit.expected_patch,
- };
-
- let prediction = teacher.predict(example).await;
-
- prediction
-}
@@ -1,4 +0,0 @@
-pub mod context;
-pub mod distill;
-pub mod llm_client;
-pub mod teacher;
@@ -1,266 +0,0 @@
-use crate::{
- example::Example,
- source_location::SourceLocation,
- training::{
- context::{ContextType, collect_context, strip_special_tags},
- llm_client::LlmClient,
- },
-};
-use anthropic::{Message, RequestContent, ResponseContent, Role};
-use anyhow::Result;
-
-pub struct TeacherModel {
- pub llm_name: String,
- pub context: ContextType,
- pub client: LlmClient,
-}
-
-#[derive(Debug, serde::Serialize)]
-pub struct TeacherOutput {
- parsed_output: String,
- prompt: String,
- raw_llm_response: String,
- context: String,
- diff: String,
-}
-
-impl TeacherModel {
- const PROMPT: &str = include_str!("teacher.prompt.md");
- pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
- pub(crate) const REGION_END: &str = "<|editable_region_end|>";
- pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
-
- /// Number of lines to include before the cursor position
- pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
-
- /// Number of lines to include after the cursor position
- pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
-
- /// Truncate edit history to this number of last lines
- const MAX_HISTORY_LINES: usize = 128;
-
- pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
- TeacherModel {
- llm_name,
- context,
- client,
- }
- }
-
- pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
- let name = input.unique_name();
- let worktree_dir = input.setup_worktree(name).await?;
- let cursor: SourceLocation = input
- .cursor_position
- .parse()
- .expect("Failed to parse cursor position");
-
- let context = collect_context(&self.context, &worktree_dir, cursor.clone());
- let edit_history = Self::format_edit_history(&input.edit_history);
-
- let prompt = Self::PROMPT
- .replace("{{context}}", &context)
- .replace("{{edit_history}}", &edit_history);
-
- let messages = vec![Message {
- role: Role::User,
- content: vec![RequestContent::Text {
- text: prompt.clone(),
- cache_control: None,
- }],
- }];
-
- let Some(response) = self
- .client
- .generate(self.llm_name.clone(), 16384, messages)
- .await?
- else {
- return Ok(None);
- };
-
- let response_text = response
- .content
- .into_iter()
- .filter_map(|content| match content {
- ResponseContent::Text { text } => Some(text),
- _ => None,
- })
- .collect::<Vec<String>>()
- .join("\n");
-
- let parsed_output = self.parse_response(&response_text);
-
- let original_editable_region = Self::extract_editable_region(&context);
- let context_after_edit = context.replace(&original_editable_region, &parsed_output);
- let context_after_edit = strip_special_tags(&context_after_edit);
- let context_before_edit = strip_special_tags(&context);
- let diff = language::unified_diff(&context_before_edit, &context_after_edit);
-
- // zeta distill --batch batch_results.txt
- // zeta distill
- // 1. Run `zeta distill <2000 examples <- all examples>` for the first time
- // - store LLM requests in a batch, don't actual send the request
- // - send the batch (2000 requests) after all inputs are processed
- // 2. `zeta send-batches`
- // - upload the batch to Anthropic
-
- // https://platform.claude.com/docs/en/build-with-claude/batch-processing
- // https://crates.io/crates/anthropic-sdk-rust
-
- // - poll for results
- // - when ready, store results in cache (a database)
- // 3. `zeta distill` again
- // - use the cached results this time
-
- Ok(Some(TeacherOutput {
- parsed_output,
- prompt,
- raw_llm_response: response_text,
- context,
- diff,
- }))
- }
-
- fn parse_response(&self, content: &str) -> String {
- let codeblock = Self::extract_last_codeblock(content);
- let editable_region = Self::extract_editable_region(&codeblock);
-
- editable_region
- }
-
- /// Extract content from the last code-fenced block if any, or else return content as is
- fn extract_last_codeblock(text: &str) -> String {
- let mut last_block = None;
- let mut search_start = 0;
-
- while let Some(start) = text[search_start..].find("```") {
- let start = start + search_start;
- let bytes = text.as_bytes();
- let mut backtick_end = start;
-
- while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
- backtick_end += 1;
- }
-
- let backtick_count = backtick_end - start;
- let closing_backticks = "`".repeat(backtick_count);
-
- if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
- let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
- last_block = Some(code_block.to_string());
- search_start = backtick_end + end_pos + backtick_count;
- } else {
- break;
- }
- }
-
- last_block.unwrap_or_else(|| text.to_string())
- }
-
- fn extract_editable_region(text: &str) -> String {
- let start = text
- .find(Self::REGION_START)
- .map_or(0, |pos| pos + Self::REGION_START.len());
- let end = text.find(Self::REGION_END).unwrap_or(text.len());
-
- text[start..end].to_string()
- }
-
- /// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
- fn format_edit_history(edit_history: &str) -> String {
- let lines = edit_history
- .lines()
- .filter(|&s| Self::is_content_line(s))
- .collect::<Vec<_>>();
-
- let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
- &lines[lines.len() - Self::MAX_HISTORY_LINES..]
- } else {
- &lines
- };
- history_lines.join("\n")
- }
-
- fn is_content_line(s: &str) -> bool {
- s.starts_with("-")
- || s.starts_with("+")
- || s.starts_with(" ")
- || s.starts_with("---")
- || s.starts_with("+++")
- || s.starts_with("@@")
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_parse_response() {
- let teacher = TeacherModel::new(
- "test".to_string(),
- ContextType::CurrentFile,
- LlmClient::dummy(),
- );
- let response = "This is a test response.";
- let parsed = teacher.parse_response(response);
- assert_eq!(parsed, response.to_string());
-
- let response = indoc::indoc! {"
- Some thinking
-
- `````
- actual response
- `````
- "};
- let parsed = teacher.parse_response(response);
- assert_eq!(parsed, "actual response");
- }
-
- #[test]
- fn test_extract_last_code_block() {
- let text = indoc::indoc! {"
- Some thinking
-
- ```
- first block
- ```
-
- `````
- last block
- `````
- "};
- let last_block = TeacherModel::extract_last_codeblock(text);
- assert_eq!(last_block, "last block");
- }
-
- #[test]
- fn test_extract_editable_region() {
- let teacher = TeacherModel::new(
- "test".to_string(),
- ContextType::CurrentFile,
- LlmClient::dummy(),
- );
- let response = indoc::indoc! {"
- some lines
- are
- here
- <|editable_region_start|>
- one
- two three
-
- <|editable_region_end|>
- more
- lines here
- "};
- let parsed = teacher.parse_response(response);
- assert_eq!(
- parsed,
- indoc::indoc! {"
- one
- two three
-
- "}
- );
- }
-}
@@ -26,6 +26,7 @@ serde.workspace = true
smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
+zeta_prompt.workspace = true
[dev-dependencies]
env_logger.workspace = true
@@ -1,6 +1,6 @@
-use crate::RelatedExcerpt;
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
use std::ops::Range;
+use zeta_prompt::RelatedExcerpt;
#[cfg(not(test))]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
input_ranges
.into_iter()
- .map(|range| {
- let offset_range = range.to_offset(buffer);
- RelatedExcerpt {
- point_range: range,
- anchor_range: buffer.anchor_before(offset_range.start)
- ..buffer.anchor_after(offset_range.end),
- text: buffer.as_rope().slice(offset_range),
- }
+ .map(|range| RelatedExcerpt {
+ row_range: range.start.row..range.end.row,
+ text: buffer.text_for_range(range).collect(),
})
.collect()
}
@@ -3,13 +3,13 @@ use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
-use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
+use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
-use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
collections::hash_map,
ops::Range,
+ path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
+pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
const IDENTIFIER_LINE_COUNT: u32 = 3;
pub struct RelatedExcerptStore {
project: WeakEntity<Project>,
- related_files: Vec<RelatedFile>,
+ related_files: Arc<[RelatedFile]>,
+ related_file_buffers: Vec<Entity<Buffer>>,
cache: HashMap<Identifier, Arc<CacheEntry>>,
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
identifier_line_count: u32,
@@ -68,82 +70,6 @@ struct CachedDefinition {
anchor_range: Range<Anchor>,
}
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedFile {
- #[serde(serialize_with = "serialize_project_path")]
- pub path: ProjectPath,
- #[serde(skip)]
- pub buffer: WeakEntity<Buffer>,
- pub excerpts: Vec<RelatedExcerpt>,
- pub max_row: u32,
-}
-
-impl RelatedFile {
- pub fn merge_excerpts(&mut self) {
- self.excerpts.sort_unstable_by(|a, b| {
- a.point_range
- .start
- .cmp(&b.point_range.start)
- .then(b.point_range.end.cmp(&a.point_range.end))
- });
-
- let mut index = 1;
- while index < self.excerpts.len() {
- if self.excerpts[index - 1]
- .point_range
- .end
- .cmp(&self.excerpts[index].point_range.start)
- .is_ge()
- {
- let removed = self.excerpts.remove(index);
- if removed
- .point_range
- .end
- .cmp(&self.excerpts[index - 1].point_range.end)
- .is_gt()
- {
- self.excerpts[index - 1].point_range.end = removed.point_range.end;
- self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
- }
- } else {
- index += 1;
- }
- }
- }
-}
-
-#[derive(Clone, Debug, Serialize)]
-pub struct RelatedExcerpt {
- #[serde(skip)]
- pub anchor_range: Range<Anchor>,
- #[serde(serialize_with = "serialize_point_range")]
- pub point_range: Range<Point>,
- #[serde(serialize_with = "serialize_rope")]
- pub text: Rope,
-}
-
-fn serialize_project_path<S: Serializer>(
- project_path: &ProjectPath,
- serializer: S,
-) -> Result<S::Ok, S::Error> {
- project_path.path.serialize(serializer)
-}
-
-fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
- rope.to_string().serialize(serializer)
-}
-
-fn serialize_point_range<S: Serializer>(
- range: &Range<Point>,
- serializer: S,
-) -> Result<S::Ok, S::Error> {
- [
- [range.start.row, range.start.column],
- [range.end.row, range.end.column],
- ]
- .serialize(serializer)
-}
-
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
RelatedExcerptStore {
project: project.downgrade(),
update_tx,
- related_files: Vec::new(),
+ related_files: Vec::new().into(),
+ related_file_buffers: Vec::new(),
cache: Default::default(),
identifier_line_count: IDENTIFIER_LINE_COUNT,
}
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
self.update_tx.unbounded_send((buffer, position)).ok();
}
- pub fn related_files(&self) -> &[RelatedFile] {
- &self.related_files
+ pub fn related_files(&self) -> Arc<[RelatedFile]> {
+ self.related_files.clone()
+ }
+
+ pub fn related_files_with_buffers(
+ &self,
+ ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
+ self.related_files
+ .iter()
+ .cloned()
+ .zip(self.related_file_buffers.iter().cloned())
+ }
+
+ pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
+ self.related_files = files.into();
}
async fn fetch_excerpts(
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
}
mean_definition_latency /= cache_miss_count.max(1) as u32;
- let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
+ let (new_cache, related_files, related_file_buffers) =
+ rebuild_related_files(&project, new_cache, cx).await?;
if let Some(file) = &file {
log::debug!(
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
this.update(cx, |this, cx| {
this.cache = new_cache;
- this.related_files = related_files;
+ this.related_files = related_files.into();
+ this.related_file_buffers = related_file_buffers;
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
}
async fn rebuild_related_files(
+ project: &Entity<Project>,
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
cx: &mut AsyncApp,
-) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
+) -> Result<(
+ HashMap<Identifier, Arc<CacheEntry>>,
+ Vec<RelatedFile>,
+ Vec<Entity<Buffer>>,
+)> {
let mut snapshots = HashMap::default();
+ let mut worktree_root_names = HashMap::default();
for entry in new_entries.values() {
for definition in &entry.definitions {
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
.read_with(cx, |buffer, _| buffer.snapshot())?,
);
}
+ let worktree_id = definition.path.worktree_id;
+ if let hash_map::Entry::Vacant(e) =
+ worktree_root_names.entry(definition.path.worktree_id)
+ {
+ project.read_with(cx, |project, cx| {
+ if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
+ e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
+ }
+ })?;
+ }
}
}
Ok(cx
.background_spawn(async move {
- let mut files = Vec::<RelatedFile>::new();
+ let mut files = Vec::new();
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
let mut paths_by_buffer = HashMap::default();
for entry in new_entries.values() {
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
continue;
};
let excerpts = assemble_excerpts(snapshot, ranges);
- files.push(RelatedFile {
- path: project_path.clone(),
- buffer: buffer.downgrade(),
- excerpts,
- max_row: snapshot.max_point().row,
- });
+ let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
+ continue;
+ };
+
+ let path = Path::new(&format!(
+ "{}/{}",
+ root_name,
+ project_path.path.as_unix_str()
+ ))
+ .into();
+
+ files.push((
+ buffer,
+ RelatedFile {
+ path,
+ excerpts,
+ max_row: snapshot.max_point().row,
+ },
+ ));
}
- files.sort_by_key(|file| file.path.clone());
- (new_entries, files)
+ files.sort_by_key(|(_, file)| file.path.clone());
+ let (related_buffers, related_files) = files.into_iter().unzip();
+
+ (new_entries, related_files, related_buffers)
})
.await)
}
@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
&excerpts,
&[
(
- "src/company.rs",
+ "root/src/company.rs",
&[indoc! {"
pub struct Company {
owner: Arc<Person>,
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
}"}],
),
(
- "src/main.rs",
+ "root/src/main.rs",
&[
indoc! {"
pub struct Session {
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
],
),
(
- "src/person.rs",
+ "root/src/person.rs",
&[
indoc! {"
impl Person {
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
.iter()
.map(|excerpt| excerpt.text.to_string())
.collect::<Vec<_>>();
- (file.path.path.as_unix_str(), excerpts)
+ (file.path.to_str().unwrap(), excerpts)
})
.collect::<Vec<_>>();
let expected_excerpts = expected_files
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
if excerpt.text.is_empty() {
continue;
}
- if current_row < excerpt.point_range.start.row {
+ if current_row < excerpt.row_range.start {
writeln!(&mut output, "…").unwrap();
}
- current_row = excerpt.point_range.start.row;
+ current_row = excerpt.row_range.start;
for line in excerpt.text.to_string().lines() {
output.push_str(line);
@@ -17,7 +17,6 @@ anyhow.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
-cloud_zeta2_prompt.workspace = true
codestral.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
@@ -46,6 +45,7 @@ ui_input.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
+zeta_prompt.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }
@@ -17,7 +17,7 @@ use gpui::{
};
use multi_buffer::MultiBuffer;
use project::Project;
-use text::OffsetRangeExt;
+use text::Point;
use ui::{
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
StyledTypography as _, h_flex, v_flex,
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
) -> Self {
let store = EditPredictionStore::global(client, user_store, cx);
- let mut debug_rx = store.update(cx, |store, _| store.debug_info());
+ let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
let _update_task = cx.spawn_in(window, async move |this, cx| {
while let Some(event) = debug_rx.next().await {
this.update_in(cx, |this, window, cx| {
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
self.handle_context_retrieval_finished(info, window, cx);
}
}
- DebugEvent::EditPredictionRequested(_) => {}
+ DebugEvent::EditPredictionStarted(_) => {}
+ DebugEvent::EditPredictionFinished(_) => {}
}
}
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
run.finished_at = Some(info.timestamp);
run.metadata = info.metadata;
- let project = self.project.clone();
let related_files = self
.store
.read(cx)
- .context_for_project(&self.project, cx)
- .to_vec();
+ .context_for_project_with_buffers(&self.project, cx)
+ .map_or(Vec::new(), |files| files.collect());
let editor = run.editor.clone();
let multibuffer = run.editor.read(cx).buffer().clone();
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
cx.spawn_in(window, async move |this, cx| {
let mut paths = Vec::new();
- for related_file in related_files {
- let (buffer, point_ranges): (_, Vec<_>) =
- if let Some(buffer) = related_file.buffer.upgrade() {
- let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
-
- (
- buffer,
- related_file
- .excerpts
- .iter()
- .map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
- .collect(),
- )
- } else {
- (
- project
- .update(cx, |project, cx| {
- project.open_buffer(related_file.path.clone(), cx)
- })?
- .await?,
- related_file
- .excerpts
- .iter()
- .map(|excerpt| excerpt.point_range.clone())
- .collect(),
- )
- };
+ for (related_file, buffer) in related_files {
+ let point_ranges = related_file
+ .excerpts
+ .iter()
+ .map(|excerpt| {
+ Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
+ })
+ .collect::<Vec<_>>();
cx.update(|_, cx| {
let path = PathKey::for_buffer(&buffer, cx);
paths.push((path, buffer, point_ranges));
@@ -1,5 +1,4 @@
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
-use cloud_zeta2_prompt::write_codeblock;
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
use feature_flags::FeatureFlag;
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
for event in &prediction.inputs.events {
- write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
+ formatted_inputs.push_str("```diff\n");
+ zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
+ formatted_inputs.push_str("```\n\n");
}
- write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
-
- for included_file in &prediction.inputs.included_files {
- let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
+ write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
+ for included_file in prediction.inputs.related_files.as_ref() {
write!(
&mut formatted_inputs,
"### {}\n\n",
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
)
.unwrap();
- write_codeblock(
- &included_file.path,
- &included_file.excerpts,
- if included_file.path == prediction.inputs.cursor_path {
- cursor_insertions.as_slice()
- } else {
- &[]
- },
- included_file.max_row,
- false,
- &mut formatted_inputs,
- );
+ for excerpt in included_file.excerpts.iter() {
+ write!(
+ &mut formatted_inputs,
+ "```{}\n{}\n```\n",
+ included_file.path.display(),
+ excerpt.text
+ )
+ .unwrap();
+ }
}
+ write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
+
+ writeln!(
+ &mut formatted_inputs,
+ "```{}\n{}<CURSOR>{}\n```\n",
+ prediction.inputs.cursor_path.display(),
+ &prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
+ &prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
+ )
+ .unwrap();
+
self.active_prediction = Some(ActivePrediction {
prediction,
feedback_editor: cx.new(|cx| {
@@ -0,0 +1,15 @@
+[package]
+name = "zeta_prompt"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/zeta_prompt.rs"
+
+[dependencies]
+serde.workspace = true
@@ -0,0 +1,165 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::Write;
+use std::ops::Range;
+use std::path::Path;
+use std::sync::Arc;
+
+pub const CURSOR_MARKER: &str = "<|user_cursor|>";
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ZetaPromptInput {
+ pub cursor_path: Arc<Path>,
+ pub cursor_excerpt: Arc<str>,
+ pub editable_range_in_excerpt: Range<usize>,
+ pub cursor_offset_in_excerpt: usize,
+ pub events: Vec<Arc<Event>>,
+ pub related_files: Arc<[RelatedFile]>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(tag = "event")]
+pub enum Event {
+ BufferChange {
+ path: Arc<Path>,
+ old_path: Arc<Path>,
+ diff: String,
+ predicted: bool,
+ in_open_source_repo: bool,
+ },
+}
+
+pub fn write_event(prompt: &mut String, event: &Event) {
+ fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
+ for component in path.components() {
+ prompt.push('/');
+ write!(prompt, "{}", component.as_os_str().display()).ok();
+ }
+ }
+ match event {
+ Event::BufferChange {
+ path,
+ old_path,
+ diff,
+ predicted,
+ in_open_source_repo: _,
+ } => {
+ if *predicted {
+ prompt.push_str("// User accepted prediction:\n");
+ }
+ prompt.push_str("--- a");
+ write_path_as_unix_str(prompt, old_path.as_ref());
+ prompt.push_str("\n+++ b");
+ write_path_as_unix_str(prompt, path.as_ref());
+ prompt.push('\n');
+ prompt.push_str(diff);
+ }
+ }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct RelatedFile {
+ pub path: Arc<Path>,
+ pub max_row: u32,
+ pub excerpts: Vec<RelatedExcerpt>,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct RelatedExcerpt {
+ pub row_range: Range<u32>,
+ pub text: String,
+}
+
+pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
+ let mut prompt = String::new();
+ write_related_files(&mut prompt, &input.related_files);
+ write_edit_history_section(&mut prompt, input);
+ write_cursor_excerpt_section(&mut prompt, input);
+ prompt
+}
+
+pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
+ push_delimited(prompt, "related_files", &[], |prompt| {
+ for file in related_files {
+ let path_str = file.path.to_string_lossy();
+ push_delimited(prompt, "related_file", &[("path", &path_str)], |prompt| {
+ for excerpt in &file.excerpts {
+ push_delimited(
+ prompt,
+ "related_excerpt",
+ &[(
+ "lines",
+ &format!(
+ "{}-{}",
+ excerpt.row_range.start + 1,
+ excerpt.row_range.end + 1
+ ),
+ )],
+ |prompt| {
+ prompt.push_str(&excerpt.text);
+ prompt.push('\n');
+ },
+ );
+ }
+ });
+ }
+ });
+}
+
+fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
+ push_delimited(prompt, "edit_history", &[], |prompt| {
+ if input.events.is_empty() {
+ prompt.push_str("(No edit history)");
+ } else {
+ for event in &input.events {
+ write_event(prompt, event);
+ }
+ }
+ });
+}
+
+fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+ push_delimited(prompt, "cursor_excerpt", &[], |prompt| {
+ let path_str = input.cursor_path.to_string_lossy();
+ push_delimited(prompt, "file", &[("path", &path_str)], |prompt| {
+ prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ push_delimited(prompt, "editable_region", &[], |prompt| {
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+ );
+ prompt.push_str(CURSOR_MARKER);
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+ );
+ });
+ prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ });
+ });
+}
+
+fn push_delimited(
+ prompt: &mut String,
+ tag: &'static str,
+ arguments: &[(&str, &str)],
+ cb: impl FnOnce(&mut String),
+) {
+ if !prompt.ends_with("\n") {
+ prompt.push('\n');
+ }
+ prompt.push('<');
+ prompt.push_str(tag);
+ for (arg_name, arg_value) in arguments {
+ write!(prompt, " {}=\"{}\"", arg_name, arg_value).ok();
+ }
+ prompt.push_str(">\n");
+
+ cb(prompt);
+
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+ prompt.push_str("</");
+ prompt.push_str(tag);
+ prompt.push_str(">\n");
+}