diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index d4c36d1f9fa43abc5d3f67d771b9fdacb1f425e9..9da9297794f477473cce7d307a4707118f585dc8 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -5,7 +5,7 @@ use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc}; pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]"; pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>"; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] pub struct ExampleSpec { #[serde(default)] pub name: String, diff --git a/crates/edit_prediction_cli/src/example.rs b/crates/edit_prediction_cli/src/example.rs index efcc4508dca513f3545c836c0641b5f22e593a82..e8352fce944afc7b235be8953b970da1c1c6143f 100644 --- a/crates/edit_prediction_cli/src/example.rs +++ b/crates/edit_prediction_cli/src/example.rs @@ -13,7 +13,7 @@ use std::ops::Range; use std::sync::Arc; use std::{ borrow::Cow, - io::{Read, Write}, + io::Read, path::{Path, PathBuf}, }; use zeta_prompt::RelatedFile; @@ -216,20 +216,6 @@ pub fn read_example_files(inputs: &[PathBuf]) -> Vec { examples } -pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) { - let mut content = String::new(); - for example in examples { - let line = serde_json::to_string(example).unwrap(); - content.push_str(&line); - content.push('\n'); - } - if let Some(output_path) = output_path { - std::fs::write(output_path, content).expect("Failed to write examples"); - } else { - std::io::stdout().write_all(&content.as_bytes()).unwrap(); - } -} - pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) { examples.sort_by(|a, b| { a.spec diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 911878b098742af34a733abba0159513f000b7dd..c12c2ae2aadb1f308da4db9a5682379392287223 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -16,15 +16,22 @@ mod score; mod split_commit; mod synthesize; use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; +use collections::HashSet; use edit_prediction::EditPredictionStore; -use gpui::Application; +use futures::channel::mpsc; +use futures::{SinkExt as _, StreamExt as _}; +use gpui::{AppContext as _, Application}; + use reqwest_client::ReqwestClient; use serde::{Deserialize, Serialize}; use std::fmt::Display; +use std::fs::{File, OpenOptions}; +use std::hash::{Hash, Hasher}; +use std::io::{BufRead, BufReader, BufWriter, Write}; use std::{path::PathBuf, sync::Arc}; use crate::distill::run_distill; -use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples}; +use crate::example::{Example, group_examples_by_repo, read_example_files}; use crate::format_prompt::run_format_prompt; use crate::load_project::run_load_project; use crate::paths::FAILED_EXAMPLES_DIR; @@ -241,6 +248,7 @@ impl EpArgs { async fn load_examples( http_client: Arc, args: &EpArgs, + output_path: Option<&PathBuf>, ) -> anyhow::Result> { let mut captured_after_timestamps = Vec::new(); let mut file_inputs = Vec::new(); @@ -294,11 +302,70 @@ async fn load_examples( } } + if let Some(path) = output_path { + resume_from_output(path, &mut examples); + } + Progress::global().set_total_examples(examples.len()); Ok(examples) } +fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 { + let mut hasher = collections::FxHasher::default(); + spec.hash(&mut hasher); + hasher.finish() +} + +fn resume_from_output(path: &PathBuf, examples: &mut Vec) { + let file = match File::open(path) { + Ok(f) => f, + Err(_) => return, + }; + + let input_hashes: HashSet = examples.iter().map(|e| spec_hash(&e.spec)).collect(); + + let reader = BufReader::new(file); + let mut kept_lines = Vec::new(); + let mut kept_hashes = HashSet::default(); + + for line in reader.lines() { + let line = match line { + Ok(l) => l, + Err(_) => continue, + }; + + if let Ok(output_example) = serde_json::from_str::(&line) { + let hash = spec_hash(&output_example.spec); + if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) { + kept_hashes.insert(hash); + kept_lines.push(line); + } + } + } + + let total = examples.len(); + let already_processed = kept_hashes.len(); + + eprintln!( + "Resuming: {}/{} examples already processed", + already_processed, total + ); + + let file = OpenOptions::new() + .write(true) + .truncate(true) + .open(path) + .expect("Failed to open output file for rewriting"); + let mut writer = BufWriter::new(file); + for line in &kept_lines { + writeln!(writer, "{}", line).expect("Failed to write to output file"); + } + writer.flush().expect("Failed to flush output file"); + + examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec))); +} + fn main() { let args = EpArgs::parse(); @@ -361,7 +428,8 @@ fn main() { cx.spawn(async move |cx| { let result = async { - let mut examples = load_examples(app_state.client.http_client(), &args).await?; + let mut examples = + load_examples(app_state.client.http_client(), &args, output.as_ref()).await?; if let Command::Predict(args) = &command { predict::sync_batches(&args.provider).await?; @@ -369,6 +437,29 @@ fn main() { let failfast_on_single_example = examples.len() == 1; + let output_sender: Option> = + if args.output.is_some() || !matches!(command, Command::Eval(_)) { + output.as_ref().map(|path| { + let file = OpenOptions::new() + .create(true) + .append(true) + .open(path) + .expect("Failed to open output file"); + let mut writer = BufWriter::new(file); + let (sender, mut receiver) = mpsc::unbounded::(); + cx.background_spawn(async move { + while let Some(line) = receiver.next().await { + writeln!(writer, "{}", line).expect("Failed to write example"); + writer.flush().expect("Failed to flush output"); + } + }) + .detach(); + sender + }) + } else { + None + }; + let mut grouped_examples = group_examples_by_repo(&mut examples); let example_batches = grouped_examples.chunks_mut(args.max_parallelism); @@ -437,15 +528,24 @@ fn main() { ) .await; } + + if let Some(ref mut sender) = output_sender.clone() { + let line = serde_json::to_string(example).unwrap(); + sender + .send(line) + .await + .expect("Failed to send to output writer"); + } else if args.output.is_none() && !matches!(command, Command::Eval(_)) + { + let line = serde_json::to_string(example).unwrap(); + println!("{}", line); + } } }); futures::future::join_all(futures).await; } - Progress::global().finalize(); - if args.output.is_some() || !matches!(command, Command::Eval(_)) { - write_examples(&examples, output.as_ref()); - } + Progress::global().finalize(); match &command { Command::Predict(args) => predict::sync_batches(&args.provider).await?,