From da401c4e7b999989f96d2477125acee90414d29c Mon Sep 17 00:00:00 2001 From: Oleksiy Syvokon Date: Tue, 24 Mar 2026 19:42:27 +0200 Subject: [PATCH] ep: Automate distillation (#51293) Increases timeouts, simplifies queries, and makes other changes needed to pull data more efficiently. Release Notes: - N/A --------- Co-authored-by: Ben Kunkle --- .../src/anthropic_client.rs | 18 + crates/edit_prediction_cli/src/main.rs | 68 +- .../edit_prediction_cli/src/openai_client.rs | 16 + crates/edit_prediction_cli/src/predict.rs | 85 ++- .../edit_prediction_cli/src/pull_examples.rs | 718 +++++++++++------- crates/edit_prediction_cli/src/repair.rs | 67 +- 6 files changed, 697 insertions(+), 275 deletions(-) diff --git a/crates/edit_prediction_cli/src/anthropic_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs index 869635c53a15e5c3f6cdaca7632a3e99f0b0bec1..7841e8a2cc1f5236697cae46f071123607c0b2d7 100644 --- a/crates/edit_prediction_cli/src/anthropic_client.rs +++ b/crates/edit_prediction_cli/src/anthropic_client.rs @@ -292,6 +292,14 @@ impl BatchingLlmClient { self.download_finished_batches().await } + pub fn pending_batch_count(&self) -> Result { + let connection = self.connection.lock().unwrap(); + let counts: Vec = connection.select( + sql!(SELECT COUNT(*) FROM cache WHERE batch_id IS NOT NULL AND response IS NULL), + )?()?; + Ok(counts.into_iter().next().unwrap_or(0) as usize) + } + /// Import batch results from external batch IDs (useful for recovering after database loss) pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> { for batch_id in batch_ids { @@ -831,6 +839,16 @@ impl AnthropicClient { } } + pub fn pending_batch_count(&self) -> Result { + match self { + AnthropicClient::Plain(_) => Ok(0), + AnthropicClient::Batch(batching_llm_client) => { + batching_llm_client.pending_batch_count() + } + AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"), + } + } + pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> { match self { AnthropicClient::Plain(_) => { diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 1dcd1d4aa3ad34df853e9d7b193c246f151a61b2..06fdbadbf53ce0f9f84b909081691c0097c4c5a4 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -26,6 +26,7 @@ mod split_dataset; mod synthesize; mod truncate_expected_patch; mod word_diff; +use anyhow::Context as _; use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum}; use collections::{HashMap, HashSet}; use edit_prediction::EditPredictionStore; @@ -294,6 +295,9 @@ struct PredictArgs { /// Only use cached responses, don't queue new requests for batching #[clap(long)] cache_only: bool, + /// Wait for all batches to complete before exiting (only applies to batched providers like teacher) + #[clap(long)] + wait: bool, } #[derive(Debug, Args, Clone)] @@ -762,7 +766,7 @@ async fn load_examples( "skipping Snowflake inputs because --limit is already satisfied by example files" ); } else { - let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000); + let max_rows_per_timestamp = remaining_limit_for_snowflake; if !rejected_after_timestamps.is_empty() { rejected_after_timestamps.sort(); @@ -1339,18 +1343,45 @@ fn main() { Progress::global().finalize(); + let is_markdown = args.markdown; + let write_path = in_place_temp_path.as_ref().or(output.as_ref()); match &command { Command::Predict(args) | Command::Score(args) => { predict::sync_batches(args.provider.as_ref()).await?; + if args.wait { + predict::wait_for_batches(args.provider.as_ref()).await?; + let mut examples = + std::mem::take(&mut *finished_examples.lock().unwrap()); + predict::reprocess_after_batch_wait(&mut examples, args).await?; + rewrite_output(&examples, write_path, is_markdown)?; + *finished_examples.lock().unwrap() = examples; + } } Command::Eval(args) => { predict::sync_batches(args.predict.provider.as_ref()).await?; + if args.predict.wait { + predict::wait_for_batches(args.predict.provider.as_ref()).await?; + let mut examples = + std::mem::take(&mut *finished_examples.lock().unwrap()); + predict::reprocess_after_batch_wait(&mut examples, &args.predict) + .await?; + rewrite_output(&examples, write_path, is_markdown)?; + *finished_examples.lock().unwrap() = examples; + } } Command::Qa(args) => { qa::sync_batches(args).await?; } Command::Repair(args) => { repair::sync_batches(args).await?; + if args.wait { + repair::wait_for_batches(args).await?; + let mut examples = + std::mem::take(&mut *finished_examples.lock().unwrap()); + repair::reprocess_after_batch_wait(&mut examples, args).await?; + rewrite_output(&examples, write_path, is_markdown)?; + *finished_examples.lock().unwrap() = examples; + } } _ => (), } @@ -1391,6 +1422,41 @@ fn main() { }); } +fn rewrite_output( + examples: &[Example], + output_path: Option<&PathBuf>, + markdown: bool, +) -> anyhow::Result<()> { + if markdown { + let dir = output_path.context("--markdown requires -o")?; + for example in examples { + let filename = format!("{}.md", example.spec.filename()); + let path = dir.join(&filename); + let markdown = example.spec.to_markdown(); + std::fs::write(&path, &markdown).context("Failed to write markdown file")?; + } + } else if let Some(path) = output_path { + let file = OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(path) + .context("Failed to open output file for rewriting")?; + let mut writer = BufWriter::new(file); + for example in examples { + let line = serde_json::to_string(example)?; + writeln!(writer, "{}", line)?; + } + writer.flush()?; + } else { + for example in examples { + let line = serde_json::to_string(example)?; + println!("{}", line); + } + } + Ok(()) +} + async fn handle_error( error: anyhow::Error, args: &EpArgs, diff --git a/crates/edit_prediction_cli/src/openai_client.rs b/crates/edit_prediction_cli/src/openai_client.rs index 6bc9c2d77c0d6be6e2955182ebbce096be422945..e35848aa1ccbd46d29f88a6c9a0ccfd35309114a 100644 --- a/crates/edit_prediction_cli/src/openai_client.rs +++ b/crates/edit_prediction_cli/src/openai_client.rs @@ -214,6 +214,14 @@ impl BatchingOpenAiClient { self.download_finished_batches().await } + pub fn pending_batch_count(&self) -> Result { + let connection = self.connection.lock().unwrap(); + let counts: Vec = connection.select( + sql!(SELECT COUNT(*) FROM openai_cache WHERE batch_id IS NOT NULL AND response IS NULL), + )?()?; + Ok(counts.into_iter().next().unwrap_or(0) as usize) + } + pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> { for batch_id in batch_ids { log::info!("Importing OpenAI batch {}", batch_id); @@ -672,6 +680,14 @@ impl OpenAiClient { } } + pub fn pending_batch_count(&self) -> Result { + match self { + OpenAiClient::Plain(_) => Ok(0), + OpenAiClient::Batch(batching_client) => batching_client.pending_batch_count(), + OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"), + } + } + pub async fn import_batches(&self, batch_ids: &[String]) -> Result<()> { match self { OpenAiClient::Plain(_) => { diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index df797b0abaa4933e73e40b746797ffb5581d7f79..1effca9d21a297d28ebf1eab738beead9f1af837 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -8,7 +8,7 @@ use crate::{ openai_client::OpenAiClient, parse_output::parse_prediction_output, paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR}, - progress::{ExampleProgress, InfoStyle, Step, StepProgress}, + progress::{ExampleProgress, InfoStyle, Progress, Step, StepProgress}, retrieve_context::run_context_retrieval, }; use anyhow::Context as _; @@ -699,3 +699,86 @@ pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Resu }; Ok(()) } + +pub async fn reprocess_after_batch_wait( + examples: &mut [Example], + args: &PredictArgs, +) -> anyhow::Result<()> { + let Some(PredictionProvider::Teacher(backend)) = args.provider else { + return Ok(()); + }; + + let mut reprocessed = 0; + for example in examples.iter_mut() { + let has_prediction = example + .predictions + .iter() + .any(|p| p.actual_patch.is_some() || !p.actual_output.is_empty()); + if has_prediction || example.prompt.is_none() { + continue; + } + + let example_progress = Progress::global().start_group(&example.spec.name); + let step_progress = example_progress.start(Step::Predict); + predict_teacher( + example, + backend, + true, + args.repetitions, + false, + &step_progress, + ) + .await?; + reprocessed += 1; + } + + if reprocessed > 0 { + eprintln!("Reprocessed {} example(s) with batch results", reprocessed); + } + + Ok(()) +} + +pub async fn wait_for_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> { + let poll_interval = std::time::Duration::from_secs(30); + + loop { + let pending = pending_batch_count(provider)?; + if pending == 0 { + break; + } + + eprintln!( + "Waiting for {} pending batch request(s) to complete... (polling every {}s)", + pending, + poll_interval.as_secs() + ); + std::thread::sleep(poll_interval); + + sync_batches(provider).await?; + } + + Ok(()) +} + +fn pending_batch_count(provider: Option<&PredictionProvider>) -> anyhow::Result { + match provider { + Some(PredictionProvider::Teacher(backend)) => match backend { + TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => { + let llm_client = ANTHROPIC_CLIENT.get_or_init(|| { + AnthropicClient::batch(&crate::paths::LLM_CACHE_DB) + .expect("Failed to create Anthropic client") + }); + llm_client.pending_batch_count() + } + TeacherBackend::Gpt52 => { + let llm_client = OPENAI_CLIENT.get_or_init(|| { + OpenAiClient::batch(&crate::paths::LLM_CACHE_DB) + .expect("Failed to create OpenAI client") + }); + llm_client.pending_batch_count() + } + }, + _ => Ok(0), + } +} diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index 15591ae03ccd7b0d537b437c1da2c0898e7e9446..9ea8ac3bda1fa17295dab29bb3d5c78eaa54d765 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -5,6 +5,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request}; use indoc::indoc; use serde::Deserialize; use serde_json::{Value as JsonValue, json}; +use std::collections::HashMap; use std::fmt::Write as _; use std::io::Read; use std::sync::Arc; @@ -13,17 +14,14 @@ use telemetry_events::EditPredictionRating; use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format}; -use crate::example::Example; +use crate::PredictionProvider; +use crate::example::{Example, ExamplePrompt}; use crate::progress::{InfoStyle, Progress, Step}; -const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment"; use edit_prediction::example_spec::{ExampleSpec, TelemetrySource}; pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001"; pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334"; -const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested"; -const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected"; -const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated"; -const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled"; +const SNOWFLAKE_TIMEOUT_CODE: &str = "000630"; /// Minimum Zed version for filtering captured examples. /// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples @@ -34,10 +32,13 @@ pub struct MinCaptureVersion { pub patch: u32, } -const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240; -const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240; pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2); -pub(crate) const MAX_POLL_ATTEMPTS: usize = 120; +const PARTITION_FETCH_MAX_RETRIES: usize = 3; +const PARTITION_FETCH_RETRY_DELAYS: [Duration; PARTITION_FETCH_MAX_RETRIES] = [ + Duration::from_millis(500), + Duration::from_secs(1), + Duration::from_secs(2), +]; /// Parse an input token of the form `captured-after:{timestamp}`. pub fn parse_captured_after_input(input: &str) -> Option<&str> { @@ -127,26 +128,25 @@ async fn run_sql_with_polling( .context("async query response missing statementHandle")? .clone(); - for attempt in 1..=MAX_POLL_ATTEMPTS { + for attempt in 0.. { step_progress.set_substatus(format!("polling ({attempt})")); background_executor.timer(POLL_INTERVAL).await; - response = - fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?; + response = fetch_partition_with_retries( + http_client.clone(), + base_url, + token, + &statement_handle, + 0, + background_executor.clone(), + ) + .await?; if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) { break; } } - - if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) { - anyhow::bail!( - "query still running after {} poll attempts ({} seconds)", - MAX_POLL_ATTEMPTS, - MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs() - ); - } } Ok(response) @@ -158,19 +158,29 @@ struct SnowflakeConfig { role: Option, } -async fn fetch_examples_with_query( +#[derive(Clone)] +struct QueryRetryState { + resume_after: String, + remaining_limit: Option, + offset: usize, +} + +async fn fetch_examples_with_query( http_client: Arc, step_progress: &crate::progress::StepProgress, background_executor: BackgroundExecutor, statement: &str, - bindings: JsonValue, - timeout_seconds: u64, + initial_retry_state: QueryRetryState, + make_bindings: MakeBindings, required_columns: &[&str], parse_response: for<'a> fn( &'a SnowflakeStatementResponse, - &'a std::collections::HashMap, + &'a HashMap, ) -> Result + 'a>>, -) -> Result> { +) -> Result> +where + MakeBindings: Fn(&QueryRetryState) -> JsonValue, +{ let snowflake = SnowflakeConfig { token: std::env::var("EP_SNOWFLAKE_API_KEY") .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?, @@ -179,74 +189,153 @@ async fn fetch_examples_with_query( )?, role: std::env::var("EP_SNOWFLAKE_ROLE").ok(), }; - let request = json!({ - "statement": statement, - "timeout": timeout_seconds, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": snowflake.role.as_deref(), - "bindings": bindings - }); - let response = run_sql_with_polling( - http_client.clone(), - &snowflake.base_url, - &snowflake.token, - &request, - step_progress, - background_executor, - ) - .await?; - - let total_rows = response - .result_set_meta_data - .as_ref() - .and_then(|meta| meta.num_rows) - .unwrap_or(response.data.len() as i64); - let partition_count = response - .result_set_meta_data - .as_ref() - .map(|meta| meta.partition_info.len()) - .unwrap_or(1) - .max(1); - - step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); - step_progress.set_substatus("parsing"); - - let column_indices = get_column_indices(&response.result_set_meta_data, required_columns); - - let mut parsed_examples = Vec::with_capacity(total_rows as usize); - parsed_examples.extend(parse_response(&response, &column_indices)?); - - if partition_count > 1 { - let statement_handle = response - .statement_handle + let mut requested_columns = required_columns.to_vec(); + if !requested_columns.contains(&"continuation_time") { + requested_columns.push("continuation_time"); + } + + let mut parsed_examples = Vec::new(); + let mut retry_state = initial_retry_state; + let mut retry_count = 0usize; + + loop { + let bindings = make_bindings(&retry_state); + let request = json!({ + "statement": statement, + "database": "EVENTS", + "schema": "PUBLIC", + "warehouse": "DBT", + "role": snowflake.role.as_deref(), + "bindings": bindings + }); + + let response = match run_sql_with_polling( + http_client.clone(), + &snowflake.base_url, + &snowflake.token, + &request, + step_progress, + background_executor.clone(), + ) + .await + { + Ok(response) => response, + Err(error) => { + if is_snowflake_timeout_error(&error) && !parsed_examples.is_empty() { + retry_count += 1; + step_progress.set_substatus(format!( + "retrying from {} ({retry_count})", + retry_state.resume_after + )); + continue; + } + + return Err(error); + } + }; + + let total_rows = response + .result_set_meta_data + .as_ref() + .and_then(|meta| meta.num_rows) + .unwrap_or(response.data.len() as i64); + let partition_count = response + .result_set_meta_data .as_ref() - .context("response has multiple partitions but no statementHandle")?; + .map(|meta| meta.partition_info.len()) + .unwrap_or(1) + .max(1); - for partition in 1..partition_count { - step_progress.set_substatus(format!( - "fetching partition {}/{}", - partition + 1, - partition_count - )); + step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); + step_progress.set_substatus("parsing"); - let partition_response = fetch_partition( - http_client.clone(), - &snowflake.base_url, - &snowflake.token, - statement_handle, - partition, - ) - .await?; + let column_indices = get_column_indices(&response.result_set_meta_data, &requested_columns); + let mut rows_fetched_this_attempt = 0usize; + let mut timed_out_fetching_partition = false; + + parsed_examples.extend(parse_response(&response, &column_indices)?); + rows_fetched_this_attempt += response.data.len(); + let mut last_continuation_time_this_attempt = + last_continuation_timestamp_from_response(&response, &column_indices); - parsed_examples.extend(parse_response(&partition_response, &column_indices)?); + if partition_count > 1 { + let statement_handle = response + .statement_handle + .as_ref() + .context("response has multiple partitions but no statementHandle")?; + + for partition in 1..partition_count { + step_progress.set_substatus(format!( + "fetching partition {}/{}", + partition + 1, + partition_count + )); + + let partition_response = match fetch_partition_with_retries( + http_client.clone(), + &snowflake.base_url, + &snowflake.token, + statement_handle, + partition, + background_executor.clone(), + ) + .await + { + Ok(response) => response, + Err(error) => { + if is_snowflake_timeout_error(&error) && rows_fetched_this_attempt > 0 { + timed_out_fetching_partition = true; + break; + } + + return Err(error); + } + }; + + parsed_examples.extend(parse_response(&partition_response, &column_indices)?); + rows_fetched_this_attempt += partition_response.data.len(); + + if let Some(partition_continuation_time) = + last_continuation_timestamp_from_response(&partition_response, &column_indices) + { + last_continuation_time_this_attempt = Some(partition_continuation_time); + } + } } - } - step_progress.set_substatus("done"); - Ok(parsed_examples) + if rows_fetched_this_attempt == 0 { + step_progress.set_substatus("done"); + return Ok(parsed_examples); + } + + if let Some(remaining_limit_value) = &mut retry_state.remaining_limit { + *remaining_limit_value = + remaining_limit_value.saturating_sub(rows_fetched_this_attempt); + if *remaining_limit_value == 0 { + step_progress.set_substatus("done"); + return Ok(parsed_examples); + } + } + + if !timed_out_fetching_partition { + step_progress.set_substatus("done"); + return Ok(parsed_examples); + } + + let Some(last_continuation_time_this_attempt) = last_continuation_time_this_attempt else { + step_progress.set_substatus("done"); + return Ok(parsed_examples); + }; + + retry_state.resume_after = last_continuation_time_this_attempt; + retry_state.offset = 0; + retry_count += 1; + step_progress.set_substatus(format!( + "retrying from {} ({retry_count})", + retry_state.resume_after + )); + } } pub(crate) async fn fetch_partition( @@ -338,6 +427,57 @@ pub(crate) async fn fetch_partition( }) } +async fn fetch_partition_with_retries( + http_client: Arc, + base_url: &str, + token: &str, + statement_handle: &str, + partition: usize, + background_executor: BackgroundExecutor, +) -> Result { + let mut last_error = None; + + for retry_attempt in 0..=PARTITION_FETCH_MAX_RETRIES { + match fetch_partition( + http_client.clone(), + base_url, + token, + statement_handle, + partition, + ) + .await + { + Ok(response) => return Ok(response), + Err(error) => { + if retry_attempt == PARTITION_FETCH_MAX_RETRIES + || !is_transient_partition_fetch_error(&error) + { + return Err(error); + } + + last_error = Some(error); + background_executor + .timer(PARTITION_FETCH_RETRY_DELAYS[retry_attempt]) + .await; + } + } + } + + match last_error { + Some(error) => Err(error), + None => anyhow::bail!("partition fetch retry loop exited without a result"), + } +} + +fn is_transient_partition_fetch_error(error: &anyhow::Error) -> bool { + error.chain().any(|cause| { + let message = cause.to_string(); + message.contains("failed to read Snowflake SQL API partition response body") + || message.contains("unexpected EOF") + || message.contains("peer closed connection without sending TLS close_notify") + }) +} + pub(crate) async fn run_sql( http_client: Arc, base_url: &str, @@ -379,19 +519,32 @@ pub(crate) async fn run_sql( bytes }; - if !status.is_success() && status.as_u16() != 202 { + let snowflake_response = serde_json::from_slice::(&body_bytes) + .context("failed to parse Snowflake SQL API response JSON")?; + + if !status.is_success() && status.as_u16() != 202 && !is_timeout_response(&snowflake_response) { let body_text = String::from_utf8_lossy(&body_bytes); anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text); } - serde_json::from_slice::(&body_bytes) - .context("failed to parse Snowflake SQL API response JSON") + if is_timeout_response(&snowflake_response) { + anyhow::bail!( + "snowflake sql api timed out code={} message={}", + snowflake_response.code.as_deref().unwrap_or(""), + snowflake_response + .message + .as_deref() + .unwrap_or("") + ); + } + + Ok(snowflake_response) } pub async fn fetch_rejected_examples_after( http_client: Arc, after_timestamps: &[String], - max_rows_per_timestamp: usize, + max_rows_per_timestamp: Option, offset: usize, background_executor: BackgroundExecutor, min_capture_version: Option, @@ -416,55 +569,53 @@ pub async fn fetch_rejected_examples_after( let statement = indoc! {r#" SELECT - req.event_properties:request_id::string AS request_id, - req.device_id::string AS device_id, - req.time::string AS time, - req.event_properties:input AS input, - req.event_properties:prompt::string AS prompt, - req.event_properties:output::string AS output, - rej.event_properties:was_shown::boolean AS was_shown, - rej.event_properties:reason::string AS reason, - req.event_properties:zed_version::string AS zed_version - FROM events req - INNER JOIN events rej - ON req.event_properties:request_id = rej.event_properties:request_id - WHERE req.event_type = ? - AND rej.event_type = ? - AND req.event_properties:version = 'V3' - AND rej.event_properties:was_shown = true - AND req.event_properties:input:can_collect_data = true - AND req.time > TRY_TO_TIMESTAMP_NTZ(?) + ep_request_id AS request_id, + device_id AS device_id, + requested_at::string AS continuation_time, + requested_at::string AS time, + input_payload AS input, + prompt AS prompt, + requested_output AS output, + is_ep_shown_before_rejected AS was_shown, + ep_rejected_reason AS reason, + zed_version AS zed_version + FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples + WHERE ep_outcome LIKE 'Rejected%' + AND is_ep_shown_before_rejected = true + AND requested_at > TRY_TO_TIMESTAMP_NTZ(?) AND (? IS NULL OR ( - TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ? + TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ? OR ( - TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ? - AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ? + TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ? + AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ? ) )) - ORDER BY req.time ASC + ORDER BY requested_at ASC LIMIT ? OFFSET ? "#}; - let bindings = json!({ - "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT }, - "3": { "type": "TEXT", "value": after_date }, - "4": { "type": "FIXED", "value": min_minor_str_ref }, - "5": { "type": "FIXED", "value": min_minor_str_ref }, - "6": { "type": "FIXED", "value": min_minor_str_ref }, - "7": { "type": "FIXED", "value": min_patch_str_ref }, - "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "9": { "type": "FIXED", "value": offset.to_string() } - }); - let examples = fetch_examples_with_query( http_client.clone(), &step_progress, background_executor.clone(), statement, - bindings, - DEFAULT_STATEMENT_TIMEOUT_SECONDS, + QueryRetryState { + resume_after: after_date.clone(), + remaining_limit: max_rows_per_timestamp, + offset, + }, + |retry_state| { + json!({ + "1": { "type": "TEXT", "value": retry_state.resume_after }, + "2": { "type": "FIXED", "value": min_minor_str_ref }, + "3": { "type": "FIXED", "value": min_minor_str_ref }, + "4": { "type": "FIXED", "value": min_minor_str_ref }, + "5": { "type": "FIXED", "value": min_patch_str_ref }, + "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) }, + "7": { "type": "FIXED", "value": retry_state.offset.to_string() } + }) + }, &[ "request_id", "device_id", @@ -486,10 +637,14 @@ pub async fn fetch_rejected_examples_after( Ok(all_examples) } +fn format_limit(limit: Option) -> String { + return limit.map(|l| l.to_string()).unwrap_or("NULL".to_string()); +} + pub async fn fetch_requested_examples_after( http_client: Arc, after_timestamps: &[String], - max_rows_per_timestamp: usize, + max_rows_per_timestamp: Option, offset: usize, background_executor: BackgroundExecutor, min_capture_version: Option, @@ -514,46 +669,47 @@ pub async fn fetch_requested_examples_after( let statement = indoc! {r#" SELECT - req.event_properties:request_id::string AS request_id, - req.device_id::string AS device_id, - req.time::string AS time, - req.event_properties:input AS input, - req.event_properties:zed_version::string AS zed_version - FROM events req - WHERE req.event_type = ? - AND req.event_properties:version = 'V3' - AND req.event_properties:input:can_collect_data = true - AND req.time > TRY_TO_TIMESTAMP_NTZ(?) + ep_request_id AS request_id, + device_id AS device_id, + requested_at::string AS continuation_time, + requested_at::string AS time, + input_payload AS input, + zed_version AS zed_version + FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples + WHERE requested_at > TRY_TO_TIMESTAMP_NTZ(?) AND (? IS NULL OR ( - TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ? + TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ? OR ( - TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ? - AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ? + TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ? + AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ? ) )) - ORDER BY req.time ASC + ORDER BY requested_at ASC LIMIT ? OFFSET ? "#}; - let bindings = json!({ - "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "2": { "type": "TEXT", "value": after_date }, - "3": { "type": "FIXED", "value": min_minor_str_ref }, - "4": { "type": "FIXED", "value": min_minor_str_ref }, - "5": { "type": "FIXED", "value": min_minor_str_ref }, - "6": { "type": "FIXED", "value": min_patch_str_ref }, - "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "8": { "type": "FIXED", "value": offset.to_string() } - }); - let examples = fetch_examples_with_query( http_client.clone(), &step_progress, background_executor.clone(), statement, - bindings, - DEFAULT_STATEMENT_TIMEOUT_SECONDS, + QueryRetryState { + resume_after: after_date.clone(), + remaining_limit: max_rows_per_timestamp, + offset, + }, + |retry_state| { + json!({ + "1": { "type": "TEXT", "value": retry_state.resume_after }, + "2": { "type": "FIXED", "value": min_minor_str_ref }, + "3": { "type": "FIXED", "value": min_minor_str_ref }, + "4": { "type": "FIXED", "value": min_minor_str_ref }, + "5": { "type": "FIXED", "value": min_patch_str_ref }, + "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) }, + "7": { "type": "FIXED", "value": retry_state.offset.to_string() } + }) + }, &["request_id", "device_id", "time", "input", "zed_version"], requested_examples_from_response, ) @@ -568,7 +724,7 @@ pub async fn fetch_requested_examples_after( pub async fn fetch_captured_examples_after( http_client: Arc, after_timestamps: &[String], - max_rows_per_timestamp: usize, + max_rows_per_timestamp: Option, offset: usize, background_executor: BackgroundExecutor, min_capture_version: Option, @@ -593,54 +749,51 @@ pub async fn fetch_captured_examples_after( let statement = indoc! {r#" SELECT - settled.event_properties:request_id::string AS request_id, - settled.device_id::string AS device_id, - settled.time::string AS time, - req.event_properties:input AS input, - settled.event_properties:settled_editable_region::string AS settled_editable_region, - settled.event_properties:example AS example, - req.event_properties:zed_version::string AS zed_version - FROM events settled - INNER JOIN events req - ON settled.event_properties:request_id::string = req.event_properties:request_id::string - WHERE settled.event_type = ? - AND req.event_type = ? - AND req.event_properties:version = 'V3' - AND req.event_properties:input:can_collect_data = true - AND settled.event_properties:example IS NOT NULL - AND TYPEOF(settled.event_properties:example) != 'NULL_VALUE' - AND settled.time > TRY_TO_TIMESTAMP_NTZ(?) + ep_request_id AS request_id, + device_id AS device_id, + requested_at::string AS continuation_time, + requested_at::string AS time, + input_payload AS input, + settled_editable_region AS settled_editable_region, + example_payload AS example, + zed_version AS zed_version + FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples + WHERE settled_editable_region IS NOT NULL + AND example_payload IS NOT NULL + AND requested_at > TRY_TO_TIMESTAMP_NTZ(?) AND (? IS NULL OR ( - TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ? + TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ? OR ( - TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ? - AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ? + TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ? + AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ? ) )) - ORDER BY settled.time ASC + ORDER BY requested_at ASC LIMIT ? OFFSET ? "#}; - let bindings = json!({ - "1": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT }, - "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "3": { "type": "TEXT", "value": after_date }, - "4": { "type": "FIXED", "value": min_minor_str_ref }, - "5": { "type": "FIXED", "value": min_minor_str_ref }, - "6": { "type": "FIXED", "value": min_minor_str_ref }, - "7": { "type": "FIXED", "value": min_patch_str_ref }, - "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "9": { "type": "FIXED", "value": offset.to_string() } - }); - let examples = fetch_examples_with_query( http_client.clone(), &step_progress, background_executor.clone(), statement, - bindings, - DEFAULT_STATEMENT_TIMEOUT_SECONDS, + QueryRetryState { + resume_after: after_date.clone(), + remaining_limit: max_rows_per_timestamp, + offset, + }, + |retry_state| { + json!({ + "1": { "type": "TEXT", "value": retry_state.resume_after }, + "2": { "type": "FIXED", "value": min_minor_str_ref }, + "3": { "type": "FIXED", "value": min_minor_str_ref }, + "4": { "type": "FIXED", "value": min_minor_str_ref }, + "5": { "type": "FIXED", "value": min_patch_str_ref }, + "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) }, + "7": { "type": "FIXED", "value": retry_state.offset.to_string() } + }) + }, &[ "request_id", "device_id", @@ -663,7 +816,7 @@ pub async fn fetch_captured_examples_after( pub async fn fetch_settled_examples_after( http_client: Arc, after_timestamps: &[String], - max_rows_per_timestamp: usize, + max_rows_per_timestamp: Option, offset: usize, background_executor: BackgroundExecutor, min_capture_version: Option, @@ -684,55 +837,41 @@ pub async fn fetch_settled_examples_after( let _ = min_capture_version; let statement = indoc! {r#" - WITH requested AS ( - SELECT - req.event_properties:request_id::string AS request_id, - req.device_id::string AS device_id, - req.time AS req_time, - req.time::string AS time, - req.event_properties:input AS input, - req.event_properties:format::string AS requested_format, - req.event_properties:output::string AS requested_output, - req.event_properties:zed_version::string AS zed_version - FROM events req - WHERE req.event_type = ? - AND req.event_properties:version = 'V3' - AND req.event_properties:input:can_collect_data = true - AND req.time > TRY_TO_TIMESTAMP_NTZ(?) - ) SELECT - req.request_id AS request_id, - req.device_id AS device_id, - req.time AS time, - req.input AS input, - req.requested_output AS requested_output, - settled.event_properties:settled_editable_region::string AS settled_editable_region, - req.requested_format AS requested_format, - req.zed_version AS zed_version - FROM requested req - INNER JOIN events settled - ON req.request_id = settled.event_properties:request_id::string - WHERE settled.event_type = ? - ORDER BY req.req_time ASC + ep_request_id AS request_id, + device_id AS device_id, + requested_at::string AS continuation_time, + requested_at::string AS time, + input_payload AS input, + requested_output AS requested_output, + settled_editable_region AS settled_editable_region, + requested_format AS requested_format, + zed_version AS zed_version + FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples + WHERE settled_editable_region IS NOT NULL + AND requested_at > TRY_TO_TIMESTAMP_NTZ(?) + ORDER BY requested_at ASC LIMIT ? OFFSET ? "#}; - let bindings = json!({ - "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "2": { "type": "TEXT", "value": after_date }, - "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT }, - "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "5": { "type": "FIXED", "value": offset.to_string() } - }); - let examples = fetch_examples_with_query( http_client.clone(), &step_progress, background_executor.clone(), statement, - bindings, - SETTLED_STATEMENT_TIMEOUT_SECONDS, + QueryRetryState { + resume_after: after_date.clone(), + remaining_limit: max_rows_per_timestamp, + offset, + }, + |retry_state| { + json!({ + "1": { "type": "TEXT", "value": retry_state.resume_after }, + "2": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) }, + "3": { "type": "FIXED", "value": retry_state.offset.to_string() } + }) + }, &[ "request_id", "device_id", @@ -756,7 +895,7 @@ pub async fn fetch_settled_examples_after( pub async fn fetch_rated_examples_after( http_client: Arc, inputs: &[(String, Option)], - max_rows_per_timestamp: usize, + max_rows_per_timestamp: Option, offset: usize, background_executor: BackgroundExecutor, _min_capture_version: Option, @@ -786,54 +925,48 @@ pub async fn fetch_rated_examples_after( let statement = indoc! {r#" SELECT - rated.event_properties:request_id::string AS request_id, - rated.event_properties:inputs AS inputs, - rated.event_properties:output::string AS output, - rated.event_properties:rating::string AS rating, - rated.event_properties:feedback::string AS feedback, - rated.device_id::string AS device_id, - rated.time::string AS time, - deploy.event_properties:experiment_name::string AS experiment_name, - deploy.event_properties:environment::string AS environment, - rated.event_properties:zed_version::string AS zed_version - FROM events rated - LEFT JOIN events req - ON rated.event_properties:request_id::string = req.event_properties:request_id::string - AND req.event_type = ? - LEFT JOIN events deploy - ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string - AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string - AND deploy.event_type = ? - WHERE rated.event_type = ? - AND (? IS NULL OR rated.event_properties:rating::string = ?) - AND rated.time > TRY_TO_TIMESTAMP_NTZ(?) - AND rated.event_properties:inputs IS NOT NULL - AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL - AND rated.event_properties:output IS NOT NULL - AND rated.event_properties:inputs:can_collect_data = true - ORDER BY rated.time ASC + ep_request_id AS request_id, + rated_inputs AS inputs, + rated_output AS output, + rating AS rating, + feedback AS feedback, + device_id AS device_id, + requested_at::string AS continuation_time, + requested_at::string AS time, + NULL AS experiment_name, + NULL AS environment, + zed_version AS zed_version + FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples + WHERE rating IS NOT NULL + AND (? IS NULL OR rating = ?) + AND requested_at > TRY_TO_TIMESTAMP_NTZ(?) + AND rated_inputs IS NOT NULL + AND rated_inputs:cursor_excerpt IS NOT NULL + AND rated_output IS NOT NULL + ORDER BY requested_at ASC LIMIT ? OFFSET ? "#}; - let bindings = json!({ - "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT }, - "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT }, - "4": { "type": "TEXT", "value": rating_value }, - "5": { "type": "TEXT", "value": rating_value }, - "6": { "type": "TEXT", "value": after_date }, - "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "8": { "type": "FIXED", "value": offset.to_string() } - }); - let examples = fetch_examples_with_query( http_client.clone(), &step_progress, background_executor.clone(), statement, - bindings, - DEFAULT_STATEMENT_TIMEOUT_SECONDS, + QueryRetryState { + resume_after: after_date.clone(), + remaining_limit: max_rows_per_timestamp, + offset, + }, + |retry_state| { + json!({ + "1": { "type": "TEXT", "value": rating_value }, + "2": { "type": "TEXT", "value": rating_value }, + "3": { "type": "TEXT", "value": retry_state.resume_after }, + "4": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) }, + "5": { "type": "FIXED", "value": retry_state.offset.to_string() } + }) + }, &[ "request_id", "inputs", @@ -1473,6 +1606,7 @@ fn rejected_examples_from_response<'a>( let input_json = get_json("input"); let input: Option = input_json.clone().and_then(|v| serde_json::from_value(v).ok()); + let prompt = get_string("prompt"); let output = get_string("output"); let was_shown = get_bool("was_shown"); let reason = get_string("reason"); @@ -1485,6 +1619,7 @@ fn rejected_examples_from_response<'a>( device_id, time, input, + prompt, output, was_shown, reason, @@ -1515,6 +1650,7 @@ fn build_rejected_example( device_id: String, time: String, input: ZetaPromptInput, + prompt: Option, output: String, was_shown: bool, reason: String, @@ -1536,6 +1672,13 @@ fn build_rejected_example( zed_version, ); example.spec.rejected_patch = Some(rejected_patch); + example.prompt = prompt.map(|prompt| ExamplePrompt { + input: prompt, + expected_output: String::new(), + rejected_output: Some(output), + prefill: None, + provider: PredictionProvider::default(), + }); example } @@ -1635,11 +1778,42 @@ fn build_output_patch( patch } +fn is_timeout_response(response: &SnowflakeStatementResponse) -> bool { + response.code.as_deref() == Some(SNOWFLAKE_TIMEOUT_CODE) + && response + .message + .as_deref() + .map(|message| message.to_ascii_lowercase().contains("timeout")) + .unwrap_or(false) +} + +fn is_snowflake_timeout_error(error: &anyhow::Error) -> bool { + error + .chain() + .any(|cause| cause.to_string().contains(SNOWFLAKE_TIMEOUT_CODE)) +} + +fn last_continuation_timestamp_from_response( + response: &SnowflakeStatementResponse, + column_indices: &HashMap, +) -> Option { + let continuation_time_index = column_indices.get("continuation_time").copied()?; + response + .data + .iter() + .rev() + .find_map(|row| match row.get(continuation_time_index)? { + JsonValue::String(value) => Some(value.clone()), + JsonValue::Null => None, + other => Some(other.to_string()), + }) +} + pub(crate) fn get_column_indices( meta: &Option, names: &[&str], -) -> std::collections::HashMap { - let mut indices = std::collections::HashMap::new(); +) -> HashMap { + let mut indices = HashMap::new(); if let Some(meta) = meta { for (index, col) in meta.row_type.iter().enumerate() { for &name in names { diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs index 58da6c47e91491cc785804c7f4c2aab30887a741..e8fb36eae28bc65a3f2c865bb95a22175b1d7ad0 100644 --- a/crates/edit_prediction_cli/src/repair.rs +++ b/crates/edit_prediction_cli/src/repair.rs @@ -15,7 +15,7 @@ use crate::{ openai_client::OpenAiClient, parse_output::run_parse_output, paths::LLM_CACHE_DB, - progress::{ExampleProgress, Step}, + progress::{ExampleProgress, Progress, Step}, word_diff::unified_to_word_diff, }; use anyhow::{Context as _, Result}; @@ -75,6 +75,9 @@ pub struct RepairArgs { /// Which LLM provider to use (anthropic or openai) #[clap(long, default_value = "anthropic")] pub backend: BatchProvider, + /// Wait for all batches to complete before exiting + #[clap(long)] + pub wait: bool, } fn model_for_backend(backend: BatchProvider) -> &'static str { @@ -454,6 +457,68 @@ pub async fn sync_batches(args: &RepairArgs) -> Result<()> { Ok(()) } +pub async fn reprocess_after_batch_wait(examples: &mut [Example], args: &RepairArgs) -> Result<()> { + let mut reprocessed = 0; + for example in examples.iter_mut() { + if has_successful_repair(example) || !needs_repair(example, args.confidence_threshold) { + continue; + } + + let example_progress = Progress::global().start_group(&example.spec.name); + run_repair(example, args, &example_progress).await?; + reprocessed += 1; + } + + if reprocessed > 0 { + eprintln!("Reprocessed {} example(s) with batch results", reprocessed); + } + + Ok(()) +} + +pub async fn wait_for_batches(args: &RepairArgs) -> Result<()> { + if args.no_batch { + return Ok(()); + } + + let poll_interval = std::time::Duration::from_secs(30); + + loop { + let pending = pending_batch_count(args)?; + if pending == 0 { + break; + } + + eprintln!( + "Waiting for {} pending repair batch request(s) to complete... (polling every {}s)", + pending, + poll_interval.as_secs() + ); + std::thread::sleep(poll_interval); + + sync_batches(args).await?; + } + + Ok(()) +} + +fn pending_batch_count(args: &RepairArgs) -> Result { + match args.backend { + BatchProvider::Anthropic => { + let client = ANTHROPIC_CLIENT_BATCH.get_or_init(|| { + AnthropicClient::batch(&LLM_CACHE_DB).expect("Failed to create Anthropic client") + }); + client.pending_batch_count() + } + BatchProvider::Openai => { + let client = OPENAI_CLIENT_BATCH.get_or_init(|| { + OpenAiClient::batch(&LLM_CACHE_DB).expect("Failed to create OpenAI client") + }); + client.pending_batch_count() + } + } +} + #[cfg(test)] mod tests { use super::*;