@@ -5,7 +5,7 @@ use std::{borrow::Cow, fmt::Write as _, mem, path::Path, sync::Arc};
pub const CURSOR_POSITION_MARKER: &str = "[CURSOR_POSITION]";
pub const INLINE_CURSOR_MARKER: &str = "<|user_cursor|>";
-#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
pub struct ExampleSpec {
#[serde(default)]
pub name: String,
@@ -13,7 +13,7 @@ use std::ops::Range;
use std::sync::Arc;
use std::{
borrow::Cow,
- io::{Read, Write},
+ io::Read,
path::{Path, PathBuf},
};
use zeta_prompt::RelatedFile;
@@ -216,20 +216,6 @@ pub fn read_example_files(inputs: &[PathBuf]) -> Vec<Example> {
examples
}
-pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
- let mut content = String::new();
- for example in examples {
- let line = serde_json::to_string(example).unwrap();
- content.push_str(&line);
- content.push('\n');
- }
- if let Some(output_path) = output_path {
- std::fs::write(output_path, content).expect("Failed to write examples");
- } else {
- std::io::stdout().write_all(&content.as_bytes()).unwrap();
- }
-}
-
pub fn sort_examples_by_repo_and_rev(examples: &mut [Example]) {
examples.sort_by(|a, b| {
a.spec
@@ -16,15 +16,22 @@ mod score;
mod split_commit;
mod synthesize;
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
+use collections::HashSet;
use edit_prediction::EditPredictionStore;
-use gpui::Application;
+use futures::channel::mpsc;
+use futures::{SinkExt as _, StreamExt as _};
+use gpui::{AppContext as _, Application};
+
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
+use std::fs::{File, OpenOptions};
+use std::hash::{Hash, Hasher};
+use std::io::{BufRead, BufReader, BufWriter, Write};
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
-use crate::example::{Example, group_examples_by_repo, read_example_files, write_examples};
+use crate::example::{Example, group_examples_by_repo, read_example_files};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
@@ -241,6 +248,7 @@ impl EpArgs {
async fn load_examples(
http_client: Arc<dyn http_client::HttpClient>,
args: &EpArgs,
+ output_path: Option<&PathBuf>,
) -> anyhow::Result<Vec<Example>> {
let mut captured_after_timestamps = Vec::new();
let mut file_inputs = Vec::new();
@@ -294,11 +302,70 @@ async fn load_examples(
}
}
+ if let Some(path) = output_path {
+ resume_from_output(path, &mut examples);
+ }
+
Progress::global().set_total_examples(examples.len());
Ok(examples)
}
+fn spec_hash(spec: &edit_prediction::example_spec::ExampleSpec) -> u64 {
+ let mut hasher = collections::FxHasher::default();
+ spec.hash(&mut hasher);
+ hasher.finish()
+}
+
+fn resume_from_output(path: &PathBuf, examples: &mut Vec<Example>) {
+ let file = match File::open(path) {
+ Ok(f) => f,
+ Err(_) => return,
+ };
+
+ let input_hashes: HashSet<u64> = examples.iter().map(|e| spec_hash(&e.spec)).collect();
+
+ let reader = BufReader::new(file);
+ let mut kept_lines = Vec::new();
+ let mut kept_hashes = HashSet::default();
+
+ for line in reader.lines() {
+ let line = match line {
+ Ok(l) => l,
+ Err(_) => continue,
+ };
+
+ if let Ok(output_example) = serde_json::from_str::<Example>(&line) {
+ let hash = spec_hash(&output_example.spec);
+ if input_hashes.contains(&hash) && !kept_hashes.contains(&hash) {
+ kept_hashes.insert(hash);
+ kept_lines.push(line);
+ }
+ }
+ }
+
+ let total = examples.len();
+ let already_processed = kept_hashes.len();
+
+ eprintln!(
+ "Resuming: {}/{} examples already processed",
+ already_processed, total
+ );
+
+ let file = OpenOptions::new()
+ .write(true)
+ .truncate(true)
+ .open(path)
+ .expect("Failed to open output file for rewriting");
+ let mut writer = BufWriter::new(file);
+ for line in &kept_lines {
+ writeln!(writer, "{}", line).expect("Failed to write to output file");
+ }
+ writer.flush().expect("Failed to flush output file");
+
+ examples.retain(|e| !kept_hashes.contains(&spec_hash(&e.spec)));
+}
+
fn main() {
let args = EpArgs::parse();
@@ -361,7 +428,8 @@ fn main() {
cx.spawn(async move |cx| {
let result = async {
- let mut examples = load_examples(app_state.client.http_client(), &args).await?;
+ let mut examples =
+ load_examples(app_state.client.http_client(), &args, output.as_ref()).await?;
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await?;
@@ -369,6 +437,29 @@ fn main() {
let failfast_on_single_example = examples.len() == 1;
+ let output_sender: Option<mpsc::UnboundedSender<String>> =
+ if args.output.is_some() || !matches!(command, Command::Eval(_)) {
+ output.as_ref().map(|path| {
+ let file = OpenOptions::new()
+ .create(true)
+ .append(true)
+ .open(path)
+ .expect("Failed to open output file");
+ let mut writer = BufWriter::new(file);
+ let (sender, mut receiver) = mpsc::unbounded::<String>();
+ cx.background_spawn(async move {
+ while let Some(line) = receiver.next().await {
+ writeln!(writer, "{}", line).expect("Failed to write example");
+ writer.flush().expect("Failed to flush output");
+ }
+ })
+ .detach();
+ sender
+ })
+ } else {
+ None
+ };
+
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
@@ -437,15 +528,24 @@ fn main() {
)
.await;
}
+
+ if let Some(ref mut sender) = output_sender.clone() {
+ let line = serde_json::to_string(example).unwrap();
+ sender
+ .send(line)
+ .await
+ .expect("Failed to send to output writer");
+ } else if args.output.is_none() && !matches!(command, Command::Eval(_))
+ {
+ let line = serde_json::to_string(example).unwrap();
+ println!("{}", line);
+ }
}
});
futures::future::join_all(futures).await;
}
- Progress::global().finalize();
- if args.output.is_some() || !matches!(command, Command::Eval(_)) {
- write_examples(&examples, output.as_ref());
- }
+ Progress::global().finalize();
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await?,