Detailed changes
@@ -1,5 +1,5 @@
use std::{
- collections::{BTreeSet, HashMap},
+ collections::HashMap,
io::{IsTerminal, Write},
sync::Arc,
};
@@ -125,21 +125,10 @@ fn write_aggregated_scores(
.peekable();
let has_edit_predictions = edit_predictions.peek().is_some();
let aggregated_result = EvaluationResult {
- context: Scores::aggregate(successful.iter().map(|r| &r.context)),
edit_prediction: has_edit_predictions.then(|| Scores::aggregate(edit_predictions)),
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
/ successful.len(),
- context_lines_found_in_context: successful
- .iter()
- .map(|r| r.context_lines_found_in_context)
- .sum::<usize>()
- / successful.len(),
- context_lines_in_expected_patch: successful
- .iter()
- .map(|r| r.context_lines_in_expected_patch)
- .sum::<usize>()
- / successful.len(),
};
writeln!(w, "\n{}", "-".repeat(80))?;
@@ -261,11 +250,8 @@ fn write_eval_result(
#[derive(Debug, Default)]
pub struct EvaluationResult {
pub edit_prediction: Option<Scores>,
- pub context: Scores,
pub prompt_len: usize,
pub generated_len: usize,
- pub context_lines_in_expected_patch: usize,
- pub context_lines_found_in_context: usize,
}
#[derive(Default, Debug)]
@@ -363,14 +349,6 @@ impl std::fmt::Display for EvaluationResult {
impl EvaluationResult {
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(
- f,
- r#"
-### Context Scores
-{}
-"#,
- self.context.to_markdown(),
- )?;
if let Some(prediction) = &self.edit_prediction {
write!(
f,
@@ -387,34 +365,18 @@ impl EvaluationResult {
writeln!(f, "### Scores\n")?;
writeln!(
f,
- " Prompt Generated RetrievedContext PatchContext TP FP FN Precision Recall F1"
+ " Prompt Generated TP FP FN Precision Recall F1"
)?;
writeln!(
f,
- "āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā"
- )?;
- writeln!(
- f,
- "Context Retrieval {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
- "",
- "",
- "",
- "",
- self.context.true_positives,
- self.context.false_positives,
- self.context.false_negatives,
- self.context.precision() * 100.0,
- self.context.recall() * 100.0,
- self.context.f1_score() * 100.0
+ "āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā"
)?;
if let Some(edit_prediction) = &self.edit_prediction {
writeln!(
f,
- "Edit Prediction {:<7} {:<9} {:<16} {:<16} {:<6} {:<6} {:<6} {:>10.2} {:>7.2} {:>7.2}",
+ "Edit Prediction {:<7} {:<9} {:<6} {:<6} {:<6} {:>9.2} {:>8.2} {:>7.2}",
self.prompt_len,
self.generated_len,
- self.context_lines_found_in_context,
- self.context_lines_in_expected_patch,
edit_prediction.true_positives,
edit_prediction.false_positives,
edit_prediction.false_negatives,
@@ -434,53 +396,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
..Default::default()
};
- let actual_context_lines: HashSet<_> = preds
- .excerpts
- .iter()
- .flat_map(|excerpt| {
- excerpt
- .text
- .lines()
- .map(|line| format!("{}: {line}", excerpt.path.display()))
- })
- .collect();
-
- let mut false_positive_lines = actual_context_lines.clone();
-
- for entry in &example.expected_context {
- let mut best_alternative_score: Option<Scores> = None;
-
- for alternative in &entry.alternatives {
- let expected: HashSet<_> = alternative
- .excerpts
- .iter()
- .flat_map(|excerpt| {
- excerpt
- .text
- .lines()
- .map(|line| format!("{}: {line}", excerpt.path.display()))
- })
- .collect();
-
- let scores = Scores::new(&expected, &actual_context_lines);
-
- false_positive_lines.retain(|line| !expected.contains(line));
-
- if best_alternative_score
- .as_ref()
- .is_none_or(|best| scores.recall() > best.recall())
- {
- best_alternative_score = Some(scores);
- }
- }
-
- let best_alternative = best_alternative_score.unwrap_or_default();
- eval_result.context.false_negatives += best_alternative.false_negatives;
- eval_result.context.true_positives += best_alternative.true_positives;
- }
-
- eval_result.context.false_positives = false_positive_lines.len();
-
if predict {
// todo: alternatives for patches
let expected_patch = example
@@ -493,25 +408,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
.map(|line| line.to_string())
.collect();
- let expected_context_lines = expected_patch
- .iter()
- .filter_map(|line| {
- if let DiffLine::Context(str) = line {
- Some(String::from(*str))
- } else {
- None
- }
- })
- .collect::<BTreeSet<_>>();
- let actual_context_lines = preds
- .excerpts
- .iter()
- .flat_map(|excerpt| excerpt.text.lines().map(ToOwned::to_owned))
- .collect::<BTreeSet<_>>();
-
- let matched = expected_context_lines
- .intersection(&actual_context_lines)
- .count();
let actual_patch_lines = preds
.diff
@@ -522,8 +418,6 @@ fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> Eval
.collect();
eval_result.edit_prediction = Some(Scores::new(&expected_patch_lines, &actual_patch_lines));
- eval_result.context_lines_in_expected_patch = expected_context_lines.len();
- eval_result.context_lines_found_in_context = matched;
}
eval_result
@@ -14,7 +14,6 @@ use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
-use edit_prediction_context::Line;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
@@ -53,7 +52,6 @@ pub struct Example {
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
- pub expected_context: Vec<ExpectedContextEntry>,
}
pub type ActualExcerpt = Excerpt;
@@ -64,25 +62,6 @@ pub struct Excerpt {
pub text: String,
}
-#[derive(Default, Clone, Debug, Serialize, Deserialize)]
-pub struct ExpectedContextEntry {
- pub heading: String,
- pub alternatives: Vec<ExpectedExcerptSet>,
-}
-
-#[derive(Default, Clone, Debug, Serialize, Deserialize)]
-pub struct ExpectedExcerptSet {
- pub heading: String,
- pub excerpts: Vec<ExpectedExcerpt>,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ExpectedExcerpt {
- pub path: PathBuf,
- pub text: String,
- pub required_lines: Vec<Line>,
-}
-
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
Json,
@@ -132,7 +111,6 @@ impl NamedExample {
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
- expected_context: Vec::new(),
},
};
@@ -197,30 +175,10 @@ impl NamedExample {
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
- let heading = mem::take(&mut text);
- match current_section {
- Section::ExpectedExcerpts => {
- named.example.expected_context.push(ExpectedContextEntry {
- heading,
- alternatives: Vec::new(),
- });
- }
- _ => {}
- }
+ mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
- let heading = mem::take(&mut text);
- match current_section {
- Section::ExpectedExcerpts => {
- let expected_context = &mut named.example.expected_context;
- let last_entry = expected_context.last_mut().unwrap();
- last_entry.alternatives.push(ExpectedExcerptSet {
- heading,
- excerpts: Vec::new(),
- })
- }
- _ => {}
- }
+ mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
@@ -253,41 +211,7 @@ impl NamedExample {
named.example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
- let text = mem::take(&mut text);
- for excerpt in text.split("\n…\n") {
- let (mut text, required_lines) = extract_required_lines(&excerpt);
- if !text.ends_with('\n') {
- text.push('\n');
- }
-
- if named.example.expected_context.is_empty() {
- named.example.expected_context.push(Default::default());
- }
-
- let alternatives = &mut named
- .example
- .expected_context
- .last_mut()
- .unwrap()
- .alternatives;
-
- if alternatives.is_empty() {
- alternatives.push(ExpectedExcerptSet {
- heading: String::new(),
- excerpts: vec![],
- });
- }
-
- alternatives
- .last_mut()
- .unwrap()
- .excerpts
- .push(ExpectedExcerpt {
- path: block_info.into(),
- text,
- required_lines,
- });
- }
+ mem::take(&mut text);
}
Section::ExpectedPatch => {
named.example.expected_patch = mem::take(&mut text);
@@ -561,47 +485,6 @@ impl NamedExample {
}
}
-fn extract_required_lines(text: &str) -> (String, Vec<Line>) {
- const MARKER: &str = "[ZETA]";
- let mut new_text = String::new();
- let mut required_lines = Vec::new();
- let mut skipped_lines = 0_u32;
-
- for (row, mut line) in text.split('\n').enumerate() {
- if let Some(marker_column) = line.find(MARKER) {
- let mut strip_column = marker_column;
-
- while strip_column > 0 {
- let prev_char = line[strip_column - 1..].chars().next().unwrap();
- if prev_char.is_whitespace() || ['/', '#'].contains(&prev_char) {
- strip_column -= 1;
- } else {
- break;
- }
- }
-
- let metadata = &line[marker_column + MARKER.len()..];
- if metadata.contains("required") {
- required_lines.push(Line(row as u32 - skipped_lines));
- }
-
- if strip_column == 0 {
- skipped_lines += 1;
- continue;
- }
-
- line = &line[..strip_column];
- }
-
- new_text.push_str(line);
- new_text.push('\n');
- }
-
- new_text.pop();
-
- (new_text, required_lines)
-}
-
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
@@ -656,37 +539,6 @@ impl Display for NamedExample {
)?;
}
- if !self.example.expected_context.is_empty() {
- write!(f, "\n## {EXPECTED_CONTEXT_HEADING}\n\n")?;
-
- for entry in &self.example.expected_context {
- write!(f, "\n### {}\n\n", entry.heading)?;
-
- let skip_h4 =
- entry.alternatives.len() == 1 && entry.alternatives[0].heading.is_empty();
-
- for excerpt_set in &entry.alternatives {
- if !skip_h4 {
- write!(f, "\n#### {}\n\n", excerpt_set.heading)?;
- }
-
- for excerpt in &excerpt_set.excerpts {
- write!(
- f,
- "`````{}{}\n{}`````\n\n",
- excerpt
- .path
- .extension()
- .map(|ext| format!("{} ", ext.to_string_lossy()))
- .unwrap_or_default(),
- excerpt.path.display(),
- excerpt.text
- )?;
- }
- }
- }
- }
-
Ok(())
}
}
@@ -707,38 +559,3 @@ pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
.lock_owned()
.await
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use indoc::indoc;
- use pretty_assertions::assert_eq;
-
- #[test]
- fn test_extract_required_lines() {
- let input = indoc! {"
- zero
- one // [ZETA] required
- two
- // [ZETA] something
- three
- four # [ZETA] required
- five
- "};
-
- let expected_updated_input = indoc! {"
- zero
- one
- two
- three
- four
- five
- "};
-
- let expected_required_lines = vec![Line(1), Line(4)];
-
- let (updated_input, required_lines) = extract_required_lines(input);
- assert_eq!(updated_input, expected_updated_input);
- assert_eq!(required_lines, expected_required_lines);
- }
-}
@@ -128,8 +128,6 @@ pub struct PredictArguments {
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
- #[arg(long)]
- use_expected_context: bool,
#[clap(flatten)]
zeta2: Zeta2Args,
#[clap(long)]
@@ -1,4 +1,4 @@
-use crate::example::{ActualExcerpt, ExpectedExcerpt, NamedExample};
+use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
@@ -7,16 +7,13 @@ use crate::{
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
-use collections::HashMap;
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
-use language::{Anchor, Buffer, Point};
use project::Project;
use project::buffer_store::BufferStoreEvent;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
-use std::ops::Range;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
@@ -204,15 +201,12 @@ pub async fn perform_predict(
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
- if !options.use_expected_context {
- result.planning_search_time = Some(
- search_queries_generated_at.unwrap() - start_time.unwrap(),
- );
- result.running_search_time = Some(
- search_queries_executed_at.unwrap()
- - search_queries_generated_at.unwrap(),
- );
- }
+ result.planning_search_time =
+ Some(search_queries_generated_at.unwrap() - start_time.unwrap());
+ result.running_search_time = Some(
+ search_queries_executed_at.unwrap()
+ - search_queries_generated_at.unwrap(),
+ );
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
@@ -224,37 +218,10 @@ pub async fn perform_predict(
}
});
- if options.use_expected_context {
- let context_excerpts_tasks = example
- .example
- .expected_context
- .iter()
- .flat_map(|section| {
- section.alternatives[0].excerpts.iter().map(|excerpt| {
- resolve_context_entry(project.clone(), excerpt.clone(), cx.clone())
- })
- })
- .collect::<Vec<_>>();
- let context_excerpts_vec =
- futures::future::try_join_all(context_excerpts_tasks).await?;
-
- let mut context_excerpts = HashMap::default();
- for (buffer, mut excerpts) in context_excerpts_vec {
- context_excerpts
- .entry(buffer)
- .or_insert(Vec::new())
- .append(&mut excerpts);
- }
-
- zeta.update(cx, |zeta, _cx| {
- zeta.set_context(project.clone(), context_excerpts)
- })?;
- } else {
- zeta.update(cx, |zeta, cx| {
- zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
- })?
- .await?;
- }
+ zeta.update(cx, |zeta, cx| {
+ zeta.refresh_context(project.clone(), cursor_buffer.clone(), cursor_anchor, cx)
+ })?
+ .await?;
}
let prediction = zeta
@@ -274,38 +241,6 @@ pub async fn perform_predict(
anyhow::Ok(result)
}
-async fn resolve_context_entry(
- project: Entity<Project>,
- excerpt: ExpectedExcerpt,
- mut cx: AsyncApp,
-) -> Result<(Entity<Buffer>, Vec<Range<Anchor>>)> {
- let buffer = project
- .update(&mut cx, |project, cx| {
- let project_path = project.find_project_path(&excerpt.path, cx).unwrap();
- project.open_buffer(project_path, cx)
- })?
- .await?;
-
- let ranges = buffer.read_with(&mut cx, |buffer, _| {
- let full_text = buffer.text();
- let offset = full_text
- .find(&excerpt.text)
- .expect("Expected context not found");
- let point = buffer.offset_to_point(offset);
- excerpt
- .required_lines
- .iter()
- .map(|line| {
- let row = point.row + line.0;
- let range = Point::new(row, 0)..Point::new(row + 1, 0);
- buffer.anchor_after(range.start)..buffer.anchor_before(range.end)
- })
- .collect()
- })?;
-
- Ok((buffer, ranges))
-}
-
struct RunCache {
cache_mode: CacheMode,
example_run_dir: PathBuf,