ep cli: Resume from output file (#46293)

Agus Zubiaga created

Change summary

crates/edit_prediction/src/example_spec.rs |   2 
crates/edit_prediction_cli/src/example.rs  |  16 ---
crates/edit_prediction_cli/src/main.rs     | 114 ++++++++++++++++++++++-
3 files changed, 109 insertions(+), 23 deletions(-)

Detailed changes

crates/edit_prediction/src/example_spec.rs 🔗

@@ -5,7 +5,7 @@ use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
 pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
 pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
 
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
 pub struct ExampleSpec {
     #[serde(default)]
     pub name: String,

crates/edit_prediction_cli/src/example.rs 🔗

@@ -13,7 +13,7 @@ use std::ops::Range;
 use std::sync::Arc;
 use std::{
     borrow::Cow,
-    io::{Read, Write},
+    io::Read,
     path::{Path, PathBuf},
 };
 use zeta_prompt::RelatedFile;
@@ -216,20 +216,6 @@ pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
     examples
 }
 
-pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
-    let mut content = String::new();
-    for example in examples {
-        let line = serde_json::to_string(example).unwrap();
-        content.push_str(&line);
-        content.push('\n');
-    }
-    if let Some(output_path) = output_path {
-        std::fs::write(output_path, content).expect("Failed to write examples");
-    } else {
-        std::io::stdout().write_all(&content.as_bytes()).unwrap();
-    }
-}
-
 pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
     examples.sort_by(|a, b| {
         a.spec

crates/edit_prediction_cli/src/main.rs 🔗

@@ -16,15 +16,22 @@ mod score;
 mod split_commit;
 mod synthesize;
 use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
+use collections::HashSet;
 use edit_prediction::EditPredictionStore;
-use gpui::Application;
+use futures::channel::mpsc;
+use futures::{SinkExt as _, StreamExt as _};
+use gpui::{AppContext as _, Application};
+
 use reqwest_client::ReqwestClient;
 use serde::{Deserialize, Serialize};
 use std::fmt::Display;
+use std::fs::{File, OpenOptions};
+use std::hash::{Hash, Hasher};
+use std::io::{BufRead, BufReader, BufWriter, Write};
 use std::{path::PathBuf, sync::Arc};
 
 use crate::distill::run_distill;
-use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
+use crate::example::{Example, group_examples_by_repo, read_example_files};
 use crate::format_prompt::run_format_prompt;
 use crate::load_project::run_load_project;
 use crate::paths::FAILED_EXAMPLES_DIR;
@@ -241,6 +248,7 @@ impl EpArgs {
 async fn load_examples(
     http_client: Arc<dyn http_client::HttpClient>,
     args: &EpArgs,
+    output_path: Option<&PathBuf>,
 ) -> anyhow::Result<Vec<Example>> {
     let mut captured_after_timestamps = Vec::new();
     let mut file_inputs = Vec::new();
@@ -294,11 +302,70 @@ async fn load_examples(
         }
     }
 
+    if let Some(path) = output_path {
+        resume_from_output(path, &mut examples);
+    }
+
     Progress::global().set_total_examples(examples.len());
 
     Ok(examples)
 }
 
+fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 { // Returns a 64-bit fingerprint of `spec`, computed with `collections::FxHasher`.
+    let mut hasher = collections::FxHasher::default();
+    spec.hash(&mut hasher); // relies on the `Hash` derive this change adds to `ExampleSpec`
+    hasher.finish()
+}
+
+fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) { // Removes from `examples` every entry whose spec already appears in the output file at `path`, and rewrites that file to contain only those matching lines.
+    let file = match File::open(path) {
+        Ok(f) => f,
+        Err(_) => return, // any open failure (not only a missing file) is treated as "nothing to resume"; `examples` and the file are left untouched
+    };
+
+    let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect(); // specs are compared by their 64-bit hash only, never by full equality
+
+    let reader = BufReader::new(file);
+    let mut kept_lines = Vec::new();
+    let mut kept_hashes = HashSet::default();
+
+    for line in reader.lines() {
+        let line = match line {
+            Ok(l) => l,
+            Err(_) => continue, // a line that cannot be read is skipped, so it is not carried over by the rewrite below
+        };
+
+        if let Ok(output_example) = serde_json::from_str::<Example>(&line) { // lines that do not deserialize as an `Example` are dropped as well
+            let hash = spec_hash(&output_example.spec);
+            if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) { // keep only the first line per spec, and only for specs present in the current input
+                kept_hashes.insert(hash);
+                kept_lines.push(line); // the original text is kept verbatim rather than re-serialized
+            }
+        }
+    }
+
+    let total = examples.len(); // counts input examples before filtering
+    let already_processed = kept_hashes.len(); // counts distinct spec hashes kept from the output file
+
+    eprintln!(
+        "Resuming: {}/{} examples already processed",
+        already_processed, total
+    );
+
+    let file = OpenOptions::new() // reopen the same path for writing; the read handle was moved into `reader`
+        .write(true)
+        .truncate(true) // existing contents are discarded before the kept lines are written back
+        .open(path)
+        .expect("Failed to open output file for rewriting"); // panics instead of returning an error
+    let mut writer = BufWriter::new(file);
+    for line in &kept_lines {
+        writeln!(writer, "{}", line).expect("Failed to write to output file");
+    }
+    writer.flush().expect("Failed to flush output file"); // explicit flush so a write error panics here rather than being ignored on drop
+
+    examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec))); // only examples without a kept output line remain to be processed
+}
+
 fn main() {
     let args = EpArgs::parse();
 
@@ -361,7 +428,8 @@ fn main() {
 
         cx.spawn(async move |cx| {
             let result = async {
-                let mut examples = load_examples(app_state.client.http_client(), &args).await?;
+                let mut examples =
+                    load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
 
                 if let Command::Predict(args) = &command {
                     predict::sync_batches(&args.provider).await?;
@@ -369,6 +437,29 @@ fn main() {
 
                 let failfast_on_single_example = examples.len() == 1;
 
+                let output_sender: Option<mpsc::UnboundedSender<String>> =
+                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
+                        output.as_ref().map(|path| {
+                            let file = OpenOptions::new()
+                                .create(true)
+                                .append(true)
+                                .open(path)
+                                .expect("Failed to open output file");
+                            let mut writer = BufWriter::new(file);
+                            let (sender, mut receiver) = mpsc::unbounded::<String>();
+                            cx.background_spawn(async move {
+                                while let Some(line) = receiver.next().await {
+                                    writeln!(writer, "{}", line).expect("Failed to write example");
+                                    writer.flush().expect("Failed to flush output");
+                                }
+                            })
+                            .detach();
+                            sender
+                        })
+                    } else {
+                        None
+                    };
+
                 let mut grouped_examples = group_examples_by_repo(&mut examples);
                 let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
 
@@ -437,15 +528,24 @@ fn main() {
                                 )
                                 .await;
                             }
+
+                            if let Some(ref mut sender) = output_sender.clone() {
+                                let line = serde_json::to_string(example).unwrap();
+                                sender
+                                    .send(line)
+                                    .await
+                                    .expect("Failed to send to output writer");
+                            } else if args.output.is_none() && !matches!(command, Command::Eval(_))
+                            {
+                                let line = serde_json::to_string(example).unwrap();
+                                println!("{}", line);
+                            }
                         }
                     });
                     futures::future::join_all(futures).await;
                 }
-                Progress::global().finalize();
 
-                if args.output.is_some() || !matches!(command, Command::Eval(_)) {
-                    write_examples(&examples, output.as_ref());
-                }
+                Progress::global().finalize();
 
                 match &command {
                     Command::Predict(args) => predict::sync_batches(&args.provider).await?,