ep: Automate distillation (#51293)

Oleksiy Syvokon and Ben Kunkle created

Increases timeouts, simplifies queries, and makes other changes needed to
pull data more efficiently.

Release Notes:

- N/A

---------

Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

crates/edit_prediction_cli/src/anthropic_client.rs |  18 
crates/edit_prediction_cli/src/main.rs             |  68 +
crates/edit_prediction_cli/src/openai_client.rs    |  16 
crates/edit_prediction_cli/src/predict.rs          |  85 +
crates/edit_prediction_cli/src/pull_examples.rs    | 718 +++++++++------
crates/edit_prediction_cli/src/repair.rs           |  67 +
6 files changed, 697 insertions(+), 275 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/anthropic_client.rs 🔗

@@ -292,6 +292,14 @@ impl BatchingLlmClient {
         self.download_finished_batches().await
     }
 
+    pub fn pending_batch_count(&self) -> Result<usize> {
+        let connection = self.connection.lock().unwrap();
+        let counts: Vec<i32> = connection.select(
+            sql!(SELECT COUNT(*) FROM cache WHERE batch_id IS NOT NULL AND response IS NULL),
+        )?()?;
+        Ok(counts.into_iter().next().unwrap_or(0) as usize)
+    }
+
     /// Import batch results from external batch IDs (useful for recovering after database loss)
     pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
         for batch_id in batch_ids {
@@ -831,6 +839,16 @@ impl AnthropicClient {
         }
     }
 
+    pub fn pending_batch_count(&self) -> Result<usize> {
+        match self {
+            AnthropicClient::Plain(_) => Ok(0),
+            AnthropicClient::Batch(batching_llm_client) => {
+                batching_llm_client.pending_batch_count()
+            }
+            AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
+        }
+    }
+
     pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
         match self {
             AnthropicClient::Plain(_) => {

crates/edit_prediction_cli/src/main.rs 🔗

@@ -26,6 +26,7 @@ mod split_dataset;
 mod synthesize;
 mod truncate_expected_patch;
 mod word_diff;
+use anyhow::Context as _;
 use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 use collections::{HashMap, HashSet};
 use edit_prediction::EditPredictionStore;
@@ -294,6 +295,9 @@ struct PredictArgs {
     /// Only use cached responses, don't queue new requests for batching
     #[clap(long)]
     cache_only: bool,
+    /// Wait for all batches to complete before exiting (only applies to batched providers like teacher)
+    #[clap(long)]
+    wait: bool,
 }
 
 #[derive(Debug, Args, Clone)]
@@ -762,7 +766,7 @@ async fn load_examples(
             "skipping Snowflake inputs because --limit is already satisfied by example files"
         );
     } else {
-        let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
+        let max_rows_per_timestamp = remaining_limit_for_snowflake;
 
         if !rejected_after_timestamps.is_empty() {
             rejected_after_timestamps.sort();
@@ -1339,18 +1343,45 @@ fn main() {
 
                 Progress::global().finalize();
 
+                let is_markdown = args.markdown;
+                let write_path = in_place_temp_path.as_ref().or(output.as_ref());
                 match &command {
                     Command::Predict(args) | Command::Score(args) => {
                         predict::sync_batches(args.provider.as_ref()).await?;
+                        if args.wait {
+                            predict::wait_for_batches(args.provider.as_ref()).await?;
+                            let mut examples =
+                                std::mem::take(&mut *finished_examples.lock().unwrap());
+                            predict::reprocess_after_batch_wait(&mut examples, args).await?;
+                            rewrite_output(&examples, write_path, is_markdown)?;
+                            *finished_examples.lock().unwrap() = examples;
+                        }
                     }
                     Command::Eval(args) => {
                         predict::sync_batches(args.predict.provider.as_ref()).await?;
+                        if args.predict.wait {
+                            predict::wait_for_batches(args.predict.provider.as_ref()).await?;
+                            let mut examples =
+                                std::mem::take(&mut *finished_examples.lock().unwrap());
+                            predict::reprocess_after_batch_wait(&mut examples, &args.predict)
+                                .await?;
+                            rewrite_output(&examples, write_path, is_markdown)?;
+                            *finished_examples.lock().unwrap() = examples;
+                        }
                     }
                     Command::Qa(args) => {
                         qa::sync_batches(args).await?;
                     }
                     Command::Repair(args) => {
                         repair::sync_batches(args).await?;
+                        if args.wait {
+                            repair::wait_for_batches(args).await?;
+                            let mut examples =
+                                std::mem::take(&mut *finished_examples.lock().unwrap());
+                            repair::reprocess_after_batch_wait(&mut examples, args).await?;
+                            rewrite_output(&examples, write_path, is_markdown)?;
+                            *finished_examples.lock().unwrap() = examples;
+                        }
                     }
                     _ => (),
                 }
@@ -1391,6 +1422,41 @@ fn main() {
     });
 }
 
+fn rewrite_output(
+    examples: &[Example],
+    output_path: Option<&PathBuf>,
+    markdown: bool,
+) -> anyhow::Result<()> {
+    if markdown {
+        let dir = output_path.context("--markdown requires -o")?;
+        for example in examples {
+            let filename = format!("{}.md", example.spec.filename());
+            let path = dir.join(&filename);
+            let markdown = example.spec.to_markdown();
+            std::fs::write(&path, &markdown).context("Failed to write markdown file")?;
+        }
+    } else if let Some(path) = output_path {
+        let file = OpenOptions::new()
+            .create(true)
+            .write(true)
+            .truncate(true)
+            .open(path)
+            .context("Failed to open output file for rewriting")?;
+        let mut writer = BufWriter::new(file);
+        for example in examples {
+            let line = serde_json::to_string(example)?;
+            writeln!(writer, "{}", line)?;
+        }
+        writer.flush()?;
+    } else {
+        for example in examples {
+            let line = serde_json::to_string(example)?;
+            println!("{}", line);
+        }
+    }
+    Ok(())
+}
+
 async fn handle_error(
     error: anyhow::Error,
     args: &EpArgs,

crates/edit_prediction_cli/src/openai_client.rs 🔗

@@ -214,6 +214,14 @@ impl BatchingOpenAiClient {
         self.download_finished_batches().await
     }
 
+    pub fn pending_batch_count(&self) -> Result<usize> {
+        let connection = self.connection.lock().unwrap();
+        let counts: Vec<i32> = connection.select(
+            sql!(SELECT COUNT(*) FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL),
+        )?()?;
+        Ok(counts.into_iter().next().unwrap_or(0) as usize)
+    }
+
     pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
         for batch_id in batch_ids {
             log::info!("Importing OpenAI batch {}", batch_id);
@@ -672,6 +680,14 @@ impl OpenAiClient {
         }
     }
 
+    pub fn pending_batch_count(&self) -> Result<usize> {
+        match self {
+            OpenAiClient::Plain(_) => Ok(0),
+            OpenAiClient::Batch(batching_client) => batching_client.pending_batch_count(),
+            OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
+        }
+    }
+
     pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> {
         match self {
             OpenAiClient::Plain(_) => {

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -8,7 +8,7 @@ use crate::{
     openai_client::OpenAiClient,
     parse_output::parse_prediction_output,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
-    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
+    progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
@@ -699,3 +699,86 @@ pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Resu
     };
     Ok(())
 }
+
+pub async fn reprocess_after_batch_wait(
+    examples: &mut [Example],
+    args: &PredictArgs,
+) -> anyhow::Result<()> {
+    let Some(PredictionProvider::Teacher(backend)) = args.provider else {
+        return Ok(());
+    };
+
+    let mut reprocessed = 0;
+    for example in examples.iter_mut() {
+        let has_prediction = example
+            .predictions
+            .iter()
+            .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty());
+        if has_prediction || example.prompt.is_none() {
+            continue;
+        }
+
+        let example_progress = Progress::global().start_group(&example.spec.name);
+        let step_progress = example_progress.start(Step::Predict);
+        predict_teacher(
+            example,
+            backend,
+            true,
+            args.repetitions,
+            false,
+            &step_progress,
+        )
+        .await?;
+        reprocessed += 1;
+    }
+
+    if reprocessed > 0 {
+        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
+    }
+
+    Ok(())
+}
+
+pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
+    let poll_interval = std::time::Duration::from_secs(30);
+
+    loop {
+        let pending = pending_batch_count(provider)?;
+        if pending == 0 {
+            break;
+        }
+
+        eprintln!(
+            "Waiting for {} pending batch request(s) to complete... (polling every {}s)",
+            pending,
+            poll_interval.as_secs()
+        );
+        std::thread::sleep(poll_interval);
+
+        sync_batches(provider).await?;
+    }
+
+    Ok(())
+}
+
+fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result<usize> {
+    match provider {
+        Some(PredictionProvider::Teacher(backend)) => match backend {
+            TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
+                let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
+                    AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
+                        .expect("Failed to create Anthropic client")
+                });
+                llm_client.pending_batch_count()
+            }
+            TeacherBackend::Gpt52 => {
+                let llm_client = OPENAI_CLIENT.get_or_init(|| {
+                    OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
+                        .expect("Failed to create OpenAI client")
+                });
+                llm_client.pending_batch_count()
+            }
+        },
+        _ => Ok(0),
+    }
+}

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -5,6 +5,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request};
 use indoc::indoc;
 use serde::Deserialize;
 use serde_json::{Value as JsonValue, json};
+use std::collections::HashMap;
 use std::fmt::Write as _;
 use std::io::Read;
 use std::sync::Arc;
@@ -13,17 +14,14 @@ use telemetry_events::EditPredictionRating;
 
 use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
 
-use crate::example::Example;
+use crate::PredictionProvider;
+use crate::example::{Example, ExamplePrompt};
 use crate::progress::{InfoStyle, Progress, Step};
-const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
 use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
 
 pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
 pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
-const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
-const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
-const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
-const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
+const SNOWFLAKE_TIMEOUT_CODE: &str = "000630";
 
 /// Minimum Zed version for filtering captured examples.
 /// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
@@ -34,10 +32,13 @@ pub struct MinCaptureVersion {
     pub patch: u32,
 }
 
-const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
-const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
 pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
-pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
+const PARTITION_FETCH_MAX_RETRIES: usize = 3;
+const PARTITION_FETCH_RETRY_DELAYS: [Duration; PARTITION_FETCH_MAX_RETRIES] = [
+    Duration::from_millis(500),
+    Duration::from_secs(1),
+    Duration::from_secs(2),
+];
 
 /// Parse an input token of the form `captured-after:{timestamp}`.
 pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@@ -127,26 +128,25 @@ async fn run_sql_with_polling(
             .context("async query response missing statementHandle")?
             .clone();
 
-        for attempt in 1..=MAX_POLL_ATTEMPTS {
+        for attempt in 0.. {
             step_progress.set_substatus(format!("polling ({attempt})"));
 
             background_executor.timer(POLL_INTERVAL).await;
 
-            response =
-                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
+            response = fetch_partition_with_retries(
+                http_client.clone(),
+                base_url,
+                token,
+                &statement_handle,
+                0,
+                background_executor.clone(),
+            )
+            .await?;
 
             if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
                 break;
             }
         }
-
-        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
-            anyhow::bail!(
-                "query still running after {} poll attempts ({} seconds)",
-                MAX_POLL_ATTEMPTS,
-                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
-            );
-        }
     }
 
     Ok(response)
@@ -158,19 +158,29 @@ struct SnowflakeConfig {
     role: Option<String>,
 }
 
-async fn fetch_examples_with_query(
+#[derive(Clone)]
+struct QueryRetryState {
+    resume_after: String,
+    remaining_limit: Option<usize>,
+    offset: usize,
+}
+
+async fn fetch_examples_with_query<MakeBindings>(
     http_client: Arc<dyn HttpClient>,
     step_progress: &crate::progress::StepProgress,
     background_executor: BackgroundExecutor,
     statement: &str,
-    bindings: JsonValue,
-    timeout_seconds: u64,
+    initial_retry_state: QueryRetryState,
+    make_bindings: MakeBindings,
     required_columns: &[&str],
     parse_response: for<'a> fn(
         &'a SnowflakeStatementResponse,
-        &'a std::collections::HashMap<String, usize>,
+        &'a HashMap<String, usize>,
     ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
-) -> Result<Vec<Example>> {
+) -> Result<Vec<Example>>
+where
+    MakeBindings: Fn(&QueryRetryState) -> JsonValue,
+{
     let snowflake = SnowflakeConfig {
         token: std::env::var("EP_SNOWFLAKE_API_KEY")
             .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
@@ -179,74 +189,153 @@ async fn fetch_examples_with_query(
         )?,
         role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
     };
-    let request = json!({
-        "statement": statement,
-        "timeout": timeout_seconds,
-        "database": "EVENTS",
-        "schema": "PUBLIC",
-        "warehouse": "DBT",
-        "role": snowflake.role.as_deref(),
-        "bindings": bindings
-    });
 
-    let response = run_sql_with_polling(
-        http_client.clone(),
-        &snowflake.base_url,
-        &snowflake.token,
-        &request,
-        step_progress,
-        background_executor,
-    )
-    .await?;
-
-    let total_rows = response
-        .result_set_meta_data
-        .as_ref()
-        .and_then(|meta| meta.num_rows)
-        .unwrap_or(response.data.len() as i64);
-    let partition_count = response
-        .result_set_meta_data
-        .as_ref()
-        .map(|meta| meta.partition_info.len())
-        .unwrap_or(1)
-        .max(1);
-
-    step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
-    step_progress.set_substatus("parsing");
-
-    let column_indices = get_column_indices(&response.result_set_meta_data, required_columns);
-
-    let mut parsed_examples = Vec::with_capacity(total_rows as usize);
-    parsed_examples.extend(parse_response(&response, &column_indices)?);
-
-    if partition_count > 1 {
-        let statement_handle = response
-            .statement_handle
+    let mut requested_columns = required_columns.to_vec();
+    if !requested_columns.contains(&"continuation_time") {
+        requested_columns.push("continuation_time");
+    }
+
+    let mut parsed_examples = Vec::new();
+    let mut retry_state = initial_retry_state;
+    let mut retry_count = 0usize;
+
+    loop {
+        let bindings = make_bindings(&retry_state);
+        let request = json!({
+            "statement": statement,
+            "database": "EVENTS",
+            "schema": "PUBLIC",
+            "warehouse": "DBT",
+            "role": snowflake.role.as_deref(),
+            "bindings": bindings
+        });
+
+        let response = match run_sql_with_polling(
+            http_client.clone(),
+            &snowflake.base_url,
+            &snowflake.token,
+            &request,
+            step_progress,
+            background_executor.clone(),
+        )
+        .await
+        {
+            Ok(response) => response,
+            Err(error) => {
+                if is_snowflake_timeout_error(&error) && !parsed_examples.is_empty() {
+                    retry_count += 1;
+                    step_progress.set_substatus(format!(
+                        "retrying from {} ({retry_count})",
+                        retry_state.resume_after
+                    ));
+                    continue;
+                }
+
+                return Err(error);
+            }
+        };
+
+        let total_rows = response
+            .result_set_meta_data
+            .as_ref()
+            .and_then(|meta| meta.num_rows)
+            .unwrap_or(response.data.len() as i64);
+        let partition_count = response
+            .result_set_meta_data
             .as_ref()
-            .context("response has multiple partitions but no statementHandle")?;
+            .map(|meta| meta.partition_info.len())
+            .unwrap_or(1)
+            .max(1);
 
-        for partition in 1..partition_count {
-            step_progress.set_substatus(format!(
-                "fetching partition {}/{}",
-                partition + 1,
-                partition_count
-            ));
+        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+        step_progress.set_substatus("parsing");
 
-            let partition_response = fetch_partition(
-                http_client.clone(),
-                &snowflake.base_url,
-                &snowflake.token,
-                statement_handle,
-                partition,
-            )
-            .await?;
+        let column_indices = get_column_indices(&response.result_set_meta_data, &requested_columns);
+        let mut rows_fetched_this_attempt = 0usize;
+        let mut timed_out_fetching_partition = false;
+
+        parsed_examples.extend(parse_response(&response, &column_indices)?);
+        rows_fetched_this_attempt += response.data.len();
+        let mut last_continuation_time_this_attempt =
+            last_continuation_timestamp_from_response(&response, &column_indices);
 
-            parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
+        if partition_count > 1 {
+            let statement_handle = response
+                .statement_handle
+                .as_ref()
+                .context("response has multiple partitions but no statementHandle")?;
+
+            for partition in 1..partition_count {
+                step_progress.set_substatus(format!(
+                    "fetching partition {}/{}",
+                    partition + 1,
+                    partition_count
+                ));
+
+                let partition_response = match fetch_partition_with_retries(
+                    http_client.clone(),
+                    &snowflake.base_url,
+                    &snowflake.token,
+                    statement_handle,
+                    partition,
+                    background_executor.clone(),
+                )
+                .await
+                {
+                    Ok(response) => response,
+                    Err(error) => {
+                        if is_snowflake_timeout_error(&error) && rows_fetched_this_attempt > 0 {
+                            timed_out_fetching_partition = true;
+                            break;
+                        }
+
+                        return Err(error);
+                    }
+                };
+
+                parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
+                rows_fetched_this_attempt += partition_response.data.len();
+
+                if let Some(partition_continuation_time) =
+                    last_continuation_timestamp_from_response(&partition_response, &column_indices)
+                {
+                    last_continuation_time_this_attempt = Some(partition_continuation_time);
+                }
+            }
         }
-    }
 
-    step_progress.set_substatus("done");
-    Ok(parsed_examples)
+        if rows_fetched_this_attempt == 0 {
+            step_progress.set_substatus("done");
+            return Ok(parsed_examples);
+        }
+
+        if let Some(remaining_limit_value) = &mut retry_state.remaining_limit {
+            *remaining_limit_value =
+                remaining_limit_value.saturating_sub(rows_fetched_this_attempt);
+            if *remaining_limit_value == 0 {
+                step_progress.set_substatus("done");
+                return Ok(parsed_examples);
+            }
+        }
+
+        if !timed_out_fetching_partition {
+            step_progress.set_substatus("done");
+            return Ok(parsed_examples);
+        }
+
+        let Some(last_continuation_time_this_attempt) = last_continuation_time_this_attempt else {
+            step_progress.set_substatus("done");
+            return Ok(parsed_examples);
+        };
+
+        retry_state.resume_after = last_continuation_time_this_attempt;
+        retry_state.offset = 0;
+        retry_count += 1;
+        step_progress.set_substatus(format!(
+            "retrying from {} ({retry_count})",
+            retry_state.resume_after
+        ));
+    }
 }
 
 pub(crate) async fn fetch_partition(
@@ -338,6 +427,57 @@ pub(crate) async fn fetch_partition(
     })
 }
 
+async fn fetch_partition_with_retries(
+    http_client: Arc<dyn HttpClient>,
+    base_url: &str,
+    token: &str,
+    statement_handle: &str,
+    partition: usize,
+    background_executor: BackgroundExecutor,
+) -> Result<SnowflakeStatementResponse> {
+    let mut last_error = None;
+
+    for retry_attempt in 0..=PARTITION_FETCH_MAX_RETRIES {
+        match fetch_partition(
+            http_client.clone(),
+            base_url,
+            token,
+            statement_handle,
+            partition,
+        )
+        .await
+        {
+            Ok(response) => return Ok(response),
+            Err(error) => {
+                if retry_attempt == PARTITION_FETCH_MAX_RETRIES
+                    || !is_transient_partition_fetch_error(&error)
+                {
+                    return Err(error);
+                }
+
+                last_error = Some(error);
+                background_executor
+                    .timer(PARTITION_FETCH_RETRY_DELAYS[retry_attempt])
+                    .await;
+            }
+        }
+    }
+
+    match last_error {
+        Some(error) => Err(error),
+        None => anyhow::bail!("partition fetch retry loop exited without a result"),
+    }
+}
+
+fn is_transient_partition_fetch_error(error: &anyhow::Error) -> bool {
+    error.chain().any(|cause| {
+        let message = cause.to_string();
+        message.contains("failed to read Snowflake SQL API partition response body")
+            || message.contains("unexpected EOF")
+            || message.contains("peer closed connection without sending TLS close_notify")
+    })
+}
+
 pub(crate) async fn run_sql(
     http_client: Arc<dyn HttpClient>,
     base_url: &str,
@@ -379,19 +519,32 @@ pub(crate) async fn run_sql(
         bytes
     };
 
-    if !status.is_success() && status.as_u16() != 202 {
+    let snowflake_response = serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
+        .context("failed to parse Snowflake SQL API response JSON")?;
+
+    if !status.is_success() && status.as_u16() != 202 && !is_timeout_response(&snowflake_response) {
         let body_text = String::from_utf8_lossy(&body_bytes);
         anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
     }
 
-    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
-        .context("failed to parse Snowflake SQL API response JSON")
+    if is_timeout_response(&snowflake_response) {
+        anyhow::bail!(
+            "snowflake sql api timed out code={} message={}",
+            snowflake_response.code.as_deref().unwrap_or("<no code>"),
+            snowflake_response
+                .message
+                .as_deref()
+                .unwrap_or("<no message>")
+        );
+    }
+
+    Ok(snowflake_response)
 }
 
 pub async fn fetch_rejected_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
-    max_rows_per_timestamp: usize,
+    max_rows_per_timestamp: Option<usize>,
     offset: usize,
     background_executor: BackgroundExecutor,
     min_capture_version: Option<MinCaptureVersion>,
@@ -416,55 +569,53 @@ pub async fn fetch_rejected_examples_after(
 
         let statement = indoc! {r#"
             SELECT
-                req.event_properties:request_id::string AS request_id,
-                req.device_id::string AS device_id,
-                req.time::string AS time,
-                req.event_properties:input AS input,
-                req.event_properties:prompt::string AS prompt,
-                req.event_properties:output::string AS output,
-                rej.event_properties:was_shown::boolean AS was_shown,
-                rej.event_properties:reason::string AS reason,
-                req.event_properties:zed_version::string AS zed_version
-            FROM events req
-            INNER JOIN events rej
-                ON req.event_properties:request_id = rej.event_properties:request_id
-            WHERE req.event_type = ?
-                AND rej.event_type = ?
-                AND req.event_properties:version = 'V3'
-                AND rej.event_properties:was_shown = true
-                AND req.event_properties:input:can_collect_data = true
-                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+                ep_request_id AS request_id,
+                device_id AS device_id,
+                requested_at::string AS continuation_time,
+                requested_at::string AS time,
+                input_payload AS input,
+                prompt AS prompt,
+                requested_output AS output,
+                is_ep_shown_before_rejected AS was_shown,
+                ep_rejected_reason AS reason,
+                zed_version AS zed_version
+            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+            WHERE ep_outcome LIKE 'Rejected%'
+                AND is_ep_shown_before_rejected = true
+                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
                 AND (? IS NULL OR (
-                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
+                    TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
                     OR (
-                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
-                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
+                        TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
+                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
                     )
                 ))
-            ORDER BY req.time ASC
+            ORDER BY requested_at ASC
             LIMIT ?
             OFFSET ?
         "#};
 
-        let bindings = json!({
-            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
-            "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
-            "3": { "type": "TEXT", "value": after_date },
-            "4": { "type": "FIXED", "value": min_minor_str_ref },
-            "5": { "type": "FIXED", "value": min_minor_str_ref },
-            "6": { "type": "FIXED", "value": min_minor_str_ref },
-            "7": { "type": "FIXED", "value": min_patch_str_ref },
-            "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
-            "9": { "type": "FIXED", "value": offset.to_string() }
-        });
-
         let examples = fetch_examples_with_query(
             http_client.clone(),
             &step_progress,
             background_executor.clone(),
             statement,
-            bindings,
-            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            QueryRetryState {
+                resume_after: after_date.clone(),
+                remaining_limit: max_rows_per_timestamp,
+                offset,
+            },
+            |retry_state| {
+                json!({
+                    "1": { "type": "TEXT", "value": retry_state.resume_after },
+                    "2": { "type": "FIXED", "value": min_minor_str_ref },
+                    "3": { "type": "FIXED", "value": min_minor_str_ref },
+                    "4": { "type": "FIXED", "value": min_minor_str_ref },
+                    "5": { "type": "FIXED", "value": min_patch_str_ref },
+                    "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+                    "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
+                })
+            },
             &[
                 "request_id",
                 "device_id",
@@ -486,10 +637,14 @@ pub async fn fetch_rejected_examples_after(
     Ok(all_examples)
 }
 
+fn format_limit(limit: Option<usize>) -> String {
+    return limit.map(|l| l.to_string()).unwrap_or("NULL".to_string());
+}
+
 pub async fn fetch_requested_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
-    max_rows_per_timestamp: usize,
+    max_rows_per_timestamp: Option<usize>,
     offset: usize,
     background_executor: BackgroundExecutor,
     min_capture_version: Option<MinCaptureVersion>,
@@ -514,46 +669,47 @@ pub async fn fetch_requested_examples_after(
 
         let statement = indoc! {r#"
             SELECT
-                req.event_properties:request_id::string AS request_id,
-                req.device_id::string AS device_id,
-                req.time::string AS time,
-                req.event_properties:input AS input,
-                req.event_properties:zed_version::string AS zed_version
-            FROM events req
-            WHERE req.event_type = ?
-                AND req.event_properties:version = 'V3'
-                AND req.event_properties:input:can_collect_data = true
-                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+                ep_request_id AS request_id,
+                device_id AS device_id,
+                requested_at::string AS continuation_time,
+                requested_at::string AS time,
+                input_payload AS input,
+                zed_version AS zed_version
+            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+            WHERE requested_at > TRY_TO_TIMESTAMP_NTZ(?)
                 AND (? IS NULL OR (
-                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
+                    TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
                     OR (
-                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
-                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
+                        TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
+                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
                     )
                 ))
-            ORDER BY req.time ASC
+            ORDER BY requested_at ASC
             LIMIT ?
             OFFSET ?
         "#};
 
-        let bindings = json!({
-            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
-            "2": { "type": "TEXT", "value": after_date },
-            "3": { "type": "FIXED", "value": min_minor_str_ref },
-            "4": { "type": "FIXED", "value": min_minor_str_ref },
-            "5": { "type": "FIXED", "value": min_minor_str_ref },
-            "6": { "type": "FIXED", "value": min_patch_str_ref },
-            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
-            "8": { "type": "FIXED", "value": offset.to_string() }
-        });
-
         let examples = fetch_examples_with_query(
             http_client.clone(),
             &step_progress,
             background_executor.clone(),
             statement,
-            bindings,
-            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            QueryRetryState {
+                resume_after: after_date.clone(),
+                remaining_limit: max_rows_per_timestamp,
+                offset,
+            },
+            |retry_state| {
+                json!({
+                    "1": { "type": "TEXT", "value": retry_state.resume_after },
+                    "2": { "type": "FIXED", "value": min_minor_str_ref },
+                    "3": { "type": "FIXED", "value": min_minor_str_ref },
+                    "4": { "type": "FIXED", "value": min_minor_str_ref },
+                    "5": { "type": "FIXED", "value": min_patch_str_ref },
+                    "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+                    "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
+                })
+            },
             &["request_id", "device_id", "time", "input", "zed_version"],
             requested_examples_from_response,
         )
@@ -568,7 +724,7 @@ pub async fn fetch_requested_examples_after(
 pub async fn fetch_captured_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
-    max_rows_per_timestamp: usize,
+    max_rows_per_timestamp: Option<usize>,
     offset: usize,
     background_executor: BackgroundExecutor,
     min_capture_version: Option<MinCaptureVersion>,
@@ -593,54 +749,51 @@ pub async fn fetch_captured_examples_after(
 
         let statement = indoc! {r#"
             SELECT
-                settled.event_properties:request_id::string AS request_id,
-                settled.device_id::string AS device_id,
-                settled.time::string AS time,
-                req.event_properties:input AS input,
-                settled.event_properties:settled_editable_region::string AS settled_editable_region,
-                settled.event_properties:example AS example,
-                req.event_properties:zed_version::string AS zed_version
-            FROM events settled
-            INNER JOIN events req
-                ON settled.event_properties:request_id::string = req.event_properties:request_id::string
-            WHERE settled.event_type = ?
-                AND req.event_type = ?
-                AND req.event_properties:version = 'V3'
-                AND req.event_properties:input:can_collect_data = true
-                AND settled.event_properties:example IS NOT NULL
-                AND TYPEOF(settled.event_properties:example) != 'NULL_VALUE'
-                AND settled.time > TRY_TO_TIMESTAMP_NTZ(?)
+                ep_request_id AS request_id,
+                device_id AS device_id,
+                requested_at::string AS continuation_time,
+                requested_at::string AS time,
+                input_payload AS input,
+                settled_editable_region AS settled_editable_region,
+                example_payload AS example,
+                zed_version AS zed_version
+            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+            WHERE settled_editable_region IS NOT NULL
+                AND example_payload IS NOT NULL
+                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
                 AND (? IS NULL OR (
-                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
+                    TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
                     OR (
-                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
-                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
+                        TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
+                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
                     )
                 ))
-            ORDER BY settled.time ASC
+            ORDER BY requested_at ASC
             LIMIT ?
             OFFSET ?
         "#};
 
-        let bindings = json!({
-            "1": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
-            "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
-            "3": { "type": "TEXT", "value": after_date },
-            "4": { "type": "FIXED", "value": min_minor_str_ref },
-            "5": { "type": "FIXED", "value": min_minor_str_ref },
-            "6": { "type": "FIXED", "value": min_minor_str_ref },
-            "7": { "type": "FIXED", "value": min_patch_str_ref },
-            "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
-            "9": { "type": "FIXED", "value": offset.to_string() }
-        });
-
         let examples = fetch_examples_with_query(
             http_client.clone(),
             &step_progress,
             background_executor.clone(),
             statement,
-            bindings,
-            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            QueryRetryState {
+                resume_after: after_date.clone(),
+                remaining_limit: max_rows_per_timestamp,
+                offset,
+            },
+            |retry_state| {
+                json!({
+                    "1": { "type": "TEXT", "value": retry_state.resume_after },
+                    "2": { "type": "FIXED", "value": min_minor_str_ref },
+                    "3": { "type": "FIXED", "value": min_minor_str_ref },
+                    "4": { "type": "FIXED", "value": min_minor_str_ref },
+                    "5": { "type": "FIXED", "value": min_patch_str_ref },
+                    "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+                    "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
+                })
+            },
             &[
                 "request_id",
                 "device_id",
@@ -663,7 +816,7 @@ pub async fn fetch_captured_examples_after(
 pub async fn fetch_settled_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
-    max_rows_per_timestamp: usize,
+    max_rows_per_timestamp: Option<usize>,
     offset: usize,
     background_executor: BackgroundExecutor,
     min_capture_version: Option<MinCaptureVersion>,
@@ -684,55 +837,41 @@ pub async fn fetch_settled_examples_after(
         let _ = min_capture_version;
 
         let statement = indoc! {r#"
-            WITH requested AS (
-                SELECT
-                    req.event_properties:request_id::string AS request_id,
-                    req.device_id::string AS device_id,
-                    req.time AS req_time,
-                    req.time::string AS time,
-                    req.event_properties:input AS input,
-                    req.event_properties:format::string AS requested_format,
-                    req.event_properties:output::string AS requested_output,
-                    req.event_properties:zed_version::string AS zed_version
-                FROM events req
-                WHERE req.event_type = ?
-                    AND req.event_properties:version = 'V3'
-                    AND req.event_properties:input:can_collect_data = true
-                    AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
-            )
             SELECT
-                req.request_id AS request_id,
-                req.device_id AS device_id,
-                req.time AS time,
-                req.input AS input,
-                req.requested_output AS requested_output,
-                settled.event_properties:settled_editable_region::string AS settled_editable_region,
-                req.requested_format AS requested_format,
-                req.zed_version AS zed_version
-            FROM requested req
-            INNER JOIN events settled
-                ON req.request_id = settled.event_properties:request_id::string
-            WHERE settled.event_type = ?
-            ORDER BY req.req_time ASC
+                ep_request_id AS request_id,
+                device_id AS device_id,
+                requested_at::string AS continuation_time,
+                requested_at::string AS time,
+                input_payload AS input,
+                requested_output AS requested_output,
+                settled_editable_region AS settled_editable_region,
+                requested_format AS requested_format,
+                zed_version AS zed_version
+            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+            WHERE settled_editable_region IS NOT NULL
+                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
+            ORDER BY requested_at ASC
             LIMIT ?
             OFFSET ?
         "#};
 
-        let bindings = json!({
-            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
-            "2": { "type": "TEXT", "value": after_date },
-            "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
-            "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
-            "5": { "type": "FIXED", "value": offset.to_string() }
-        });
-
         let examples = fetch_examples_with_query(
             http_client.clone(),
             &step_progress,
             background_executor.clone(),
             statement,
-            bindings,
-            SETTLED_STATEMENT_TIMEOUT_SECONDS,
+            QueryRetryState {
+                resume_after: after_date.clone(),
+                remaining_limit: max_rows_per_timestamp,
+                offset,
+            },
+            |retry_state| {
+                json!({
+                    "1": { "type": "TEXT", "value": retry_state.resume_after },
+                    "2": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+                    "3": { "type": "FIXED", "value": retry_state.offset.to_string() }
+                })
+            },
             &[
                 "request_id",
                 "device_id",
@@ -756,7 +895,7 @@ pub async fn fetch_settled_examples_after(
 pub async fn fetch_rated_examples_after(
     http_client: Arc<dyn HttpClient>,
     inputs: &[(String, Option<EditPredictionRating>)],
-    max_rows_per_timestamp: usize,
+    max_rows_per_timestamp: Option<usize>,
     offset: usize,
     background_executor: BackgroundExecutor,
     _min_capture_version: Option<MinCaptureVersion>,
@@ -786,54 +925,48 @@ pub async fn fetch_rated_examples_after(
 
         let statement = indoc! {r#"
             SELECT
-                rated.event_properties:request_id::string AS request_id,
-                rated.event_properties:inputs AS inputs,
-                rated.event_properties:output::string AS output,
-                rated.event_properties:rating::string AS rating,
-                rated.event_properties:feedback::string AS feedback,
-                rated.device_id::string AS device_id,
-                rated.time::string AS time,
-                deploy.event_properties:experiment_name::string AS experiment_name,
-                deploy.event_properties:environment::string AS environment,
-                rated.event_properties:zed_version::string AS zed_version
-            FROM events rated
-            LEFT JOIN events req
-                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
-                AND req.event_type = ?
-            LEFT JOIN events deploy
-                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
-                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
-                AND deploy.event_type = ?
-            WHERE rated.event_type = ?
-                AND (? IS NULL OR rated.event_properties:rating::string = ?)
-                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
-                AND rated.event_properties:inputs IS NOT NULL
-                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
-                AND rated.event_properties:output IS NOT NULL
-                AND rated.event_properties:inputs:can_collect_data = true
-            ORDER BY rated.time ASC
+                ep_request_id AS request_id,
+                rated_inputs AS inputs,
+                rated_output AS output,
+                rating AS rating,
+                feedback AS feedback,
+                device_id AS device_id,
+                requested_at::string AS continuation_time,
+                requested_at::string AS time,
+                NULL AS experiment_name,
+                NULL AS environment,
+                zed_version AS zed_version
+            FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+            WHERE rating IS NOT NULL
+                AND (? IS NULL OR rating = ?)
+                AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
+                AND rated_inputs IS NOT NULL
+                AND rated_inputs:cursor_excerpt IS NOT NULL
+                AND rated_output IS NOT NULL
+            ORDER BY requested_at ASC
             LIMIT ?
             OFFSET ?
         "#};
 
-        let bindings = json!({
-            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
-            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
-            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
-            "4": { "type": "TEXT", "value": rating_value },
-            "5": { "type": "TEXT", "value": rating_value },
-            "6": { "type": "TEXT", "value": after_date },
-            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
-            "8": { "type": "FIXED", "value": offset.to_string() }
-        });
-
         let examples = fetch_examples_with_query(
             http_client.clone(),
             &step_progress,
             background_executor.clone(),
             statement,
-            bindings,
-            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            QueryRetryState {
+                resume_after: after_date.clone(),
+                remaining_limit: max_rows_per_timestamp,
+                offset,
+            },
+            |retry_state| {
+                json!({
+                    "1": { "type": "TEXT", "value": rating_value },
+                    "2": { "type": "TEXT", "value": rating_value },
+                    "3": { "type": "TEXT", "value": retry_state.resume_after },
+                    "4": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+                    "5": { "type": "FIXED", "value": retry_state.offset.to_string() }
+                })
+            },
             &[
                 "request_id",
                 "inputs",
@@ -1473,6 +1606,7 @@ fn rejected_examples_from_response<'a>(
             let input_json = get_json("input");
             let input: Option<ZetaPromptInput> =
                 input_json.clone().and_then(|v| serde_json::from_value(v).ok());
+            let prompt = get_string("prompt");
             let output = get_string("output");
             let was_shown = get_bool("was_shown");
             let reason = get_string("reason");
@@ -1485,6 +1619,7 @@ fn rejected_examples_from_response<'a>(
                         device_id,
                         time,
                         input,
+                        prompt,
                         output,
                         was_shown,
                         reason,
@@ -1515,6 +1650,7 @@ fn build_rejected_example(
     device_id: String,
     time: String,
     input: ZetaPromptInput,
+    prompt: Option<String>,
     output: String,
     was_shown: bool,
     reason: String,
@@ -1536,6 +1672,13 @@ fn build_rejected_example(
         zed_version,
     );
     example.spec.rejected_patch = Some(rejected_patch);
+    example.prompt = prompt.map(|prompt| ExamplePrompt {
+        input: prompt,
+        expected_output: String::new(),
+        rejected_output: Some(output),
+        prefill: None,
+        provider: PredictionProvider::default(),
+    });
     example
 }
 
@@ -1635,11 +1778,42 @@ fn build_output_patch(
     patch
 }
 
+fn is_timeout_response(response: &SnowflakeStatementResponse) -> bool {
+    response.code.as_deref() == Some(SNOWFLAKE_TIMEOUT_CODE)
+        && response
+            .message
+            .as_deref()
+            .map(|message| message.to_ascii_lowercase().contains("timeout"))
+            .unwrap_or(false)
+}
+
+fn is_snowflake_timeout_error(error: &anyhow::Error) -> bool {
+    error
+        .chain()
+        .any(|cause| cause.to_string().contains(SNOWFLAKE_TIMEOUT_CODE))
+}
+
+fn last_continuation_timestamp_from_response(
+    response: &SnowflakeStatementResponse,
+    column_indices: &HashMap<String, usize>,
+) -> Option<String> {
+    let continuation_time_index = column_indices.get("continuation_time").copied()?;
+    response
+        .data
+        .iter()
+        .rev()
+        .find_map(|row| match row.get(continuation_time_index)? {
+            JsonValue::String(value) => Some(value.clone()),
+            JsonValue::Null => None,
+            other => Some(other.to_string()),
+        })
+}
+
 pub(crate) fn get_column_indices(
     meta: &Option<SnowflakeResultSetMetaData>,
     names: &[&str],
-) -> std::collections::HashMap<String, usize> {
-    let mut indices = std::collections::HashMap::new();
+) -> HashMap<String, usize> {
+    let mut indices = HashMap::new();
     if let Some(meta) = meta {
         for (index, col) in meta.row_type.iter().enumerate() {
             for &name in names {

crates/edit_prediction_cli/src/repair.rs 🔗

@@ -15,7 +15,7 @@ use crate::{
     openai_client::OpenAiClient,
     parse_output::run_parse_output,
     paths::LLM_CACHE_DB,
-    progress::{ExampleProgress, Step},
+    progress::{ExampleProgress, Progress, Step},
     word_diff::unified_to_word_diff,
 };
 use anyhow::{Context as _, Result};
@@ -75,6 +75,9 @@ pub struct RepairArgs {
     /// Which LLM provider to use (anthropic or openai)
     #[clap(long, default_value = "anthropic")]
     pub backend: BatchProvider,
+    /// Wait for all batches to complete before exiting
+    #[clap(long)]
+    pub wait: bool,
 }
 
 fn model_for_backend(backend: BatchProvider) -> &'static str {
@@ -454,6 +457,68 @@ pub async fn sync_batches(args: &RepairArgs) -> Result<()> {
     Ok(())
 }
 
+pub async fn reprocess_after_batch_wait(examples: &mut [Example], args: &RepairArgs) -> Result<()> {
+    let mut reprocessed = 0;
+    for example in examples.iter_mut() {
+        if has_successful_repair(example) || !needs_repair(example, args.confidence_threshold) {
+            continue;
+        }
+
+        let example_progress = Progress::global().start_group(&example.spec.name);
+        run_repair(example, args, &example_progress).await?;
+        reprocessed += 1;
+    }
+
+    if reprocessed > 0 {
+        eprintln!("Reprocessed {} example(s) with batch results", reprocessed);
+    }
+
+    Ok(())
+}
+
+pub async fn wait_for_batches(args: &RepairArgs) -> Result<()> {
+    if args.no_batch {
+        return Ok(());
+    }
+
+    let poll_interval = std::time::Duration::from_secs(30);
+
+    loop {
+        let pending = pending_batch_count(args)?;
+        if pending == 0 {
+            break;
+        }
+
+        eprintln!(
+            "Waiting for {} pending repair batch request(s) to complete... (polling every {}s)",
+            pending,
+            poll_interval.as_secs()
+        );
+        std::thread::sleep(poll_interval);
+
+        sync_batches(args).await?;
+    }
+
+    Ok(())
+}
+
+fn pending_batch_count(args: &RepairArgs) -> Result<usize> {
+    match args.backend {
+        BatchProvider::Anthropic => {
+            let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| {
+                AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client")
+            });
+            client.pending_batch_count()
+        }
+        BatchProvider::Openai => {
+            let client = OPENAI_CLIENT_BATCH.get_or_init(|| {
+                OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client")
+            });
+            client.pending_batch_count()
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;