From a1df9baa298631f07f739bb1509fff441dd51c1a Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 26 Jan 2026 02:49:44 -0800 Subject: [PATCH] Allow EP cli to fetch rejections from snowflake (#47628) Release Notes: - N/A --- crates/edit_prediction/src/capture_example.rs | 2 + crates/edit_prediction/src/example_spec.rs | 15 + crates/edit_prediction_cli/src/main.rs | 70 ++- .../edit_prediction_cli/src/pull_examples.rs | 419 +++++++++++++++++- .../edit_prediction_cli/src/split_commit.rs | 4 +- crates/edit_prediction_cli/src/synthesize.rs | 1 + 6 files changed, 483 insertions(+), 28 deletions(-) diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index 39eb1cbf45089e871be50884a682c8647a917f7f..07d97a27af2065dd33946705475a0b9127747c7f 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -173,6 +173,7 @@ pub fn capture_example( expected_patches, rejected_patch, captured_prompt_input: prompt_input, + telemetry: None, }; spec.set_cursor_excerpt( &cursor_excerpt, @@ -599,6 +600,7 @@ mod tests { .to_string() ), captured_prompt_input: example.captured_prompt_input.clone(), + telemetry: None, } ); diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index 09ef97dffc60d1eda26c292e943c5619eb0bda39..0d5832e0ac74f9a36372fbacded55864085dfd2a 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -30,6 +30,18 @@ pub struct ExampleSpec { pub rejected_patch: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub captured_prompt_input: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub telemetry: Option, +} + +/// Metadata for examples sourced from production telemetry (rejected predictions). +#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)] +pub struct TelemetrySource { + pub request_id: String, + pub device_id: String, + pub time: String, + pub rejection_reason: String, + pub was_shown: bool, } /// All data needed to run format_prompt without loading the project. @@ -239,6 +251,7 @@ impl ExampleSpec { expected_patches: Vec::new(), rejected_patch: None, captured_prompt_input: None, + telemetry: None, }; if let Some(rest) = input.strip_prefix("+++\n") @@ -486,6 +499,7 @@ mod tests { expected_patches: Vec::new(), rejected_patch: None, captured_prompt_input: None, + telemetry: None, }; // Cursor before `42` @@ -620,6 +634,7 @@ mod tests { expected_patches: Vec::new(), rejected_patch: None, captured_prompt_input: None, + telemetry: None, }; // Cursor before `42` using inline marker diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 315e954c5f52d32c77e1f71ee4af2950ace6f83d..036743b2373afa173981f2903e899fae92841d6d 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -105,8 +105,11 @@ Inputs can be file paths or special specifiers: captured-after:{timestamp} Fetch captured examples from Snowflake after the given RFC3339 timestamp. + These are examples captured via the "Capture Edit Prediction Example" action. - You can specify this multiple times and mix it with file inputs. + rejected-after:{timestamp} + Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp. + These are predictions that were shown to users but rejected (useful for DPO training). Required environment variables to connect to Snowflake: EP_SNOWFLAKE_API_KEY @@ -117,20 +120,23 @@ Inputs can be file paths or special specifiers: Examples: - # Predict from a file - ep predict examples.jsonl + # Read examples from a file + ep read examples.jsonl -o output.jsonl - # Predict from captured examples after a timestamp - ep predict captured-after:2025-01-01T00:00:00Z + # Read captured examples after a timestamp + ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl - # Mix file inputs and captured-after in the same invocation + # Read rejected predictions for DPO training + ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl + + # Mix multiple input sources ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z "#; #[derive(Subcommand, Debug, Clone)] enum Command { - /// Parse markdown examples and output a combined .jsonl file - ParseExample, + /// Read examples from files or fetch from Snowflake, output as .jsonl + Read, /// Create git worktrees for each example and load file contents LoadProject, /// Retrieve context for input examples. @@ -168,7 +174,7 @@ enum Command { impl Display for Command { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Command::ParseExample => write!(f, "parse-example"), + Command::Read => write!(f, "read"), Command::LoadProject => write!(f, "load-project"), Command::Context => write!(f, "context"), Command::FormatPrompt(args) => { @@ -357,12 +363,17 @@ async fn load_examples( background_executor: BackgroundExecutor, ) -> anyhow::Result> { let mut captured_after_timestamps = Vec::new(); + let mut rejected_after_timestamps = Vec::new(); let mut file_inputs = Vec::new(); for input in &args.inputs { let input_string = input.to_string_lossy(); if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) { captured_after_timestamps.push(timestamp.to_string()); + } else if let Some(timestamp) = + pull_examples::parse_rejected_after_input(input_string.as_ref()) + { + rejected_after_timestamps.push(timestamp.to_string()); } else { file_inputs.push(input.clone()); } @@ -377,21 +388,36 @@ async fn load_examples( if let Some(0) = remaining_limit_for_snowflake { log::info!( - "skipping captured-after inputs because --limit is already satisfied by example files" + "skipping Snowflake inputs because --limit is already satisfied by example files" ); - } else if !captured_after_timestamps.is_empty() { - captured_after_timestamps.sort(); - + } else { let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000); - let mut captured_examples = pull_examples::fetch_captured_examples_after( - http_client, - &captured_after_timestamps, - max_rows_per_timestamp, - background_executor, - ) - .await?; - examples.append(&mut captured_examples); + if !captured_after_timestamps.is_empty() { + captured_after_timestamps.sort(); + + let mut captured_examples = pull_examples::fetch_captured_examples_after( + http_client.clone(), + &captured_after_timestamps, + max_rows_per_timestamp, + background_executor.clone(), + ) + .await?; + examples.append(&mut captured_examples); + } + + if !rejected_after_timestamps.is_empty() { + rejected_after_timestamps.sort(); + + let mut rejected_examples = pull_examples::fetch_rejected_examples_after( + http_client, + &rejected_after_timestamps, + max_rows_per_timestamp, + background_executor, + ) + .await?; + examples.append(&mut rejected_examples); + } } crate::example::sort_examples_by_repo_and_rev(&mut examples); @@ -687,7 +713,7 @@ fn main() { let result = async { match &command { - Command::ParseExample => {} + Command::Read => {} Command::LoadProject => { run_load_project( example, diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index eb0c5cbcd42cf1188ce555cdc76ee879356966bd..c4928f076db1f0d98031395b6a823cec4ef52062 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -9,15 +9,21 @@ use std::io::Read; use std::sync::Arc; use std::time::Duration; -use crate::{ - example::Example, - progress::{InfoStyle, Progress, Step}, +use zeta_prompt::ZetaPromptInput; + +use crate::example::Example; +use crate::progress::{InfoStyle, Progress, Step}; +use edit_prediction::example_spec::{ + CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec, + TelemetrySource, }; -use edit_prediction::example_spec::ExampleSpec; +use std::fmt::Write as _; const SNOWFLAKE_SUCCESS_CODE: &str = "090001"; const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334"; const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured"; +const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested"; +const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected"; const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120; const POLL_INTERVAL: Duration = Duration::from_secs(2); @@ -28,6 +34,11 @@ pub fn parse_captured_after_input(input: &str) -> Option<&str> { input.strip_prefix("captured-after:") } +/// Parse an input token of the form `rejected-after:{timestamp}`. +pub fn parse_rejected_after_input(input: &str) -> Option<&str> { + input.strip_prefix("rejected-after:") +} + pub async fn fetch_captured_examples_after( http_client: Arc, after_timestamps: &[String], @@ -302,6 +313,7 @@ async fn fetch_partition( ) .header("Accept", "application/json") .header("Accept-Encoding", "gzip") + .header("User-Agent", "edit_prediction_cli") .body(AsyncBody::empty())?; let response = http_client @@ -387,6 +399,7 @@ async fn run_sql( ) .header("Content-Type", "application/json") .header("Accept", "application/json") + .header("User-Agent", "edit_prediction_cli") .body(AsyncBody::from(request_body.clone()))?; let response = http_client @@ -414,3 +427,401 @@ async fn run_sql( serde_json::from_slice::(&body_bytes) .context("failed to parse Snowflake SQL API response JSON") } + +pub async fn fetch_rejected_examples_after( + http_client: Arc, + after_timestamps: &[String], + max_rows_per_timestamp: usize, + background_executor: BackgroundExecutor, +) -> Result> { + if after_timestamps.is_empty() { + return Ok(Vec::new()); + } + + let progress = Progress::global(); + + let token = std::env::var("EP_SNOWFLAKE_API_KEY") + .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; + let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( + "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", + )?; + let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); + + let mut all_examples = Vec::new(); + + for after_date in after_timestamps.iter() { + let step_progress_name = format!("rejected>{after_date}"); + let step_progress = progress.start(Step::PullExamples, &step_progress_name); + step_progress.set_substatus("querying"); + + // Join rejected events with their corresponding request events to get the full context. + // We filter for V3 sampling data which contains the structured input we need. + // We also filter for predictions that were actually shown to the user (was_shown = true) + // to focus on explicit user rejections rather than implicit cancellations. + let statement = indoc! {r#" + SELECT + req.event_properties:request_id::string AS request_id, + req.device_id::string AS device_id, + req.time::string AS time, + req.event_properties:input AS input, + req.event_properties:prompt::string AS prompt, + req.event_properties:output::string AS output, + rej.event_properties:was_shown::boolean AS was_shown, + rej.event_properties:reason::string AS reason + FROM events req + INNER JOIN events rej + ON req.event_properties:request_id = rej.event_properties:request_id + WHERE req.event_type = ? + AND rej.event_type = ? + AND req.event_properties:version = 'V3' + AND rej.event_properties:was_shown = true + AND req.time > TRY_TO_TIMESTAMP_NTZ(?) + ORDER BY req.time ASC + LIMIT ? + "#}; + + let request = json!({ + "statement": statement, + "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, + "database": "EVENTS", + "schema": "PUBLIC", + "warehouse": "DBT", + "role": role, + "bindings": { + "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, + "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT }, + "3": { "type": "TEXT", "value": after_date }, + "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() } + } + }); + + let response = run_sql_with_polling( + http_client.clone(), + &base_url, + &token, + &request, + &step_progress, + background_executor.clone(), + ) + .await?; + + let total_rows = response + .result_set_meta_data + .as_ref() + .and_then(|m| m.num_rows) + .unwrap_or(response.data.len() as i64); + + let num_partitions = response + .result_set_meta_data + .as_ref() + .map(|m| m.partition_info.len()) + .unwrap_or(1) + .max(1); + + step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); + step_progress.set_substatus("parsing"); + + all_examples.extend(rejected_examples_from_response(&response)?); + + if num_partitions > 1 { + let statement_handle = response + .statement_handle + .as_ref() + .context("response has multiple partitions but no statementHandle")?; + + for partition in 1..num_partitions { + step_progress.set_substatus(format!( + "fetching partition {}/{}", + partition + 1, + num_partitions + )); + + let partition_response = fetch_partition( + http_client.clone(), + &base_url, + &token, + statement_handle, + partition, + ) + .await?; + + all_examples.extend(rejected_examples_from_response(&partition_response)?); + } + } + + step_progress.set_substatus("done"); + } + + Ok(all_examples) +} + +fn rejected_examples_from_response( + response: &SnowflakeStatementResponse, +) -> Result + '_> { + if let Some(code) = &response.code { + if code != SNOWFLAKE_SUCCESS_CODE { + anyhow::bail!( + "snowflake sql api returned error code={code} message={}", + response.message.as_deref().unwrap_or("") + ); + } + } + + let column_indices = get_column_indices( + &response.result_set_meta_data, + &[ + "request_id", + "device_id", + "time", + "input", + "prompt", + "output", + "was_shown", + "reason", + ], + ); + + let iter = response + .data + .iter() + .enumerate() + .filter_map(move |(row_index, data_row)| { + let get_string = |name: &str| -> Option { + let index = column_indices.get(name).copied()?; + match data_row.get(index)? { + JsonValue::String(s) => Some(s.clone()), + JsonValue::Null => None, + other => Some(other.to_string()), + } + }; + + let get_json = |name: &str| -> Option { + let index = column_indices.get(name).copied()?; + let value = data_row.get(index)?; + if value.is_null() { + return None; + } + match value { + JsonValue::String(s) => serde_json::from_str(s).ok(), + other => Some(other.clone()), + } + }; + + let get_bool = |name: &str| -> Option { + let index = column_indices.get(name).copied()?; + match data_row.get(index)? { + JsonValue::Bool(b) => Some(*b), + JsonValue::String(s) => s.parse().ok(), + _ => None, + } + }; + + let request_id_str = get_string("request_id"); + let device_id = get_string("device_id"); + let time = get_string("time"); + let input_json = get_json("input"); + let input: Option = + input_json.clone().and_then(|v| serde_json::from_value(v).ok()); + let output = get_string("output"); + let was_shown = get_bool("was_shown"); + let reason = get_string("reason"); + + match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) { + (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => { + Some(build_rejected_example( + request_id, + device_id, + time, + input, + output, + was_shown, + reason, + )) + } + _ => { + log::warn!( + "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}", + request_id_str.is_some(), + device_id.is_some(), + time.is_some(), + input_json.is_some(), + output.is_some(), + was_shown.is_some(), + reason.is_some() + ); + None + } + } + }); + + Ok(iter) +} + +fn build_rejected_example( + request_id: String, + device_id: String, + time: String, + input: ZetaPromptInput, + output: String, + was_shown: bool, + reason: String, +) -> Example { + let events: Vec = input + .events + .iter() + .map(|event| match event.as_ref() { + zeta_prompt::Event::BufferChange { + path, + old_path, + diff, + predicted, + in_open_source_repo, + } => CapturedEvent { + path: path.clone(), + old_path: old_path.clone(), + diff: diff.clone(), + predicted: *predicted, + in_open_source_repo: *in_open_source_repo, + }, + }) + .collect(); + + let related_files: Vec = input + .related_files + .iter() + .map(|rf| CapturedRelatedFile { + path: rf.path.clone(), + max_row: rf.max_row, + excerpts: rf + .excerpts + .iter() + .map(|e| CapturedRelatedExcerpt { + row_range: e.row_range.clone(), + text: e.text.to_string(), + }) + .collect(), + }) + .collect(); + + let cursor_excerpt = input.cursor_excerpt.as_ref(); + let cursor_offset = input.cursor_offset_in_excerpt; + + let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset); + + let mut edit_history = String::new(); + for event in &input.events { + zeta_prompt::write_event(&mut edit_history, event); + edit_history.push('\n'); + } + + let rejected_patch = build_rejected_patch( + &input.cursor_path, + cursor_excerpt, + &input.editable_range_in_excerpt, + &output, + ); + + let spec = ExampleSpec { + name: request_id.clone(), + repository_url: String::new(), + revision: String::new(), + tags: vec![format!("rejection:{}", reason.to_lowercase())], + reasoning: None, + uncommitted_diff: String::new(), + cursor_path: input.cursor_path.clone(), + cursor_position: build_cursor_position(cursor_excerpt, cursor_offset), + edit_history, + expected_patches: Vec::new(), + rejected_patch: Some(rejected_patch), + captured_prompt_input: Some(CapturedPromptInput { + cursor_file_content: cursor_excerpt.to_string(), + cursor_offset, + cursor_row, + cursor_column, + events, + related_files, + }), + telemetry: Some(TelemetrySource { + request_id, + device_id, + time, + rejection_reason: reason, + was_shown, + }), + }; + + Example { + spec, + prompt_inputs: None, + prompt: None, + predictions: Vec::new(), + score: Vec::new(), + state: None, + } +} + +fn compute_row_column(text: &str, offset: usize) -> (u32, u32) { + let mut row = 0u32; + let mut last_newline_offset = 0; + for (i, c) in text.char_indices() { + if i >= offset { + break; + } + if c == '\n' { + row += 1; + last_newline_offset = i + 1; + } + } + let column = (offset - last_newline_offset) as u32; + (row, column) +} + +fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String { + let before = &excerpt[..cursor_offset.min(excerpt.len())]; + let after = &excerpt[cursor_offset.min(excerpt.len())..]; + format!("{}[CURSOR_POSITION]{}", before, after) +} + +fn build_rejected_patch( + cursor_path: &std::path::Path, + cursor_excerpt: &str, + editable_range: &std::ops::Range, + model_output: &str, +) -> String { + let old_text = &cursor_excerpt[editable_range.clone()]; + + let editable_start_row = cursor_excerpt[..editable_range.start] + .chars() + .filter(|&c| c == '\n') + .count() as u32; + + let diff_body = language::unified_diff_with_offsets( + old_text, + model_output, + editable_start_row, + editable_start_row, + ); + + let mut patch = String::new(); + writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok(); + writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok(); + patch.push_str(&diff_body); + patch +} + +fn get_column_indices( + meta: &Option, + names: &[&str], +) -> std::collections::HashMap { + let mut indices = std::collections::HashMap::new(); + if let Some(meta) = meta { + for (index, col) in meta.row_type.iter().enumerate() { + for &name in names { + if col.name.eq_ignore_ascii_case(name) { + indices.insert(name.to_string(), index); + } + } + } + } + indices +} diff --git a/crates/edit_prediction_cli/src/split_commit.rs b/crates/edit_prediction_cli/src/split_commit.rs index 4a034f0c35b56266aa67dffe53c6c178f15bcfb0..8d00fd2f0846bfaf20c84a71a0f00a3e77cbc54e 100644 --- a/crates/edit_prediction_cli/src/split_commit.rs +++ b/crates/edit_prediction_cli/src/split_commit.rs @@ -363,7 +363,6 @@ pub fn generate_evaluation_example_from_ordered_commit( repository_url: repository_url.to_string(), revision: format!("{}~1", commit_hash), edit_history: split_commit.source_patch.clone(), - // cursor_position: cursor.to_string(), cursor_path: Path::new(&cursor.file).into(), cursor_position: cursor_excerpt, expected_patches: vec![split_commit.target_patch], @@ -372,6 +371,7 @@ pub fn generate_evaluation_example_from_ordered_commit( uncommitted_diff: String::new(), rejected_patch: None, captured_prompt_input: None, + telemetry: None, }) } @@ -1395,7 +1395,6 @@ Date: Mon Jan 1 00:00:00 2024 repository_url: "https://github.com/test/repo".to_string(), revision: "abc123~1".to_string(), edit_history: "patch1".to_string(), - // cursor_position: "file.rs:10:5".to_string(), cursor_path: Path::new("file.rs").into(), cursor_position: "some code<|user_cursor|>".to_string(), expected_patches: vec!["patch".to_string()], @@ -1404,6 +1403,7 @@ Date: Mon Jan 1 00:00:00 2024 uncommitted_diff: String::new(), rejected_patch: None, captured_prompt_input: None, + telemetry: None, }; let json = serde_json::to_string(&case).unwrap(); diff --git a/crates/edit_prediction_cli/src/synthesize.rs b/crates/edit_prediction_cli/src/synthesize.rs index 1d7b4eb874fc099b6a898d60be683e358a96b55b..603ee8de65663256246a5413dab6906c641d3de6 100644 --- a/crates/edit_prediction_cli/src/synthesize.rs +++ b/crates/edit_prediction_cli/src/synthesize.rs @@ -793,6 +793,7 @@ async fn build_example( expected_patches: vec![expected_patch_with_header], rejected_patch: None, captured_prompt_input: None, + telemetry: None, }; spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);