Allow EP cli to fetch rejections from snowflake (#47628)

Max Brunsfeld created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/capture_example.rs   |   2 
crates/edit_prediction/src/example_spec.rs      |  15 
crates/edit_prediction_cli/src/main.rs          |  70 ++
crates/edit_prediction_cli/src/pull_examples.rs | 419 ++++++++++++++++++
crates/edit_prediction_cli/src/split_commit.rs  |   4 
crates/edit_prediction_cli/src/synthesize.rs    |   1 
6 files changed, 483 insertions(+), 28 deletions(-)

Detailed changes

crates/edit_prediction/src/capture_example.rs 🔗

@@ -173,6 +173,7 @@ pub fn capture_example(
             expected_patches,
             rejected_patch,
             captured_prompt_input: prompt_input,
+            telemetry: None,
         };
         spec.set_cursor_excerpt(
             &cursor_excerpt,
@@ -599,6 +600,7 @@ mod tests {
                     .to_string()
                 ),
                 captured_prompt_input: example.captured_prompt_input.clone(),
+                telemetry: None,
             }
         );
 

crates/edit_prediction/src/example_spec.rs 🔗

@@ -30,6 +30,18 @@ pub struct ExampleSpec {
     pub rejected_patch: Option<String>,
     #[serde(default, skip_serializing_if = "Option::is_none")]
     pub captured_prompt_input: Option<CapturedPromptInput>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub telemetry: Option<TelemetrySource>,
+}
+
+/// Metadata for examples sourced from production telemetry (rejected predictions).
+#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
+pub struct TelemetrySource {
+    pub request_id: String,
+    pub device_id: String,
+    pub time: String,
+    pub rejection_reason: String,
+    pub was_shown: bool,
 }
 
 /// All data needed to run format_prompt without loading the project.
@@ -239,6 +251,7 @@ impl ExampleSpec {
             expected_patches: Vec::new(),
             rejected_patch: None,
             captured_prompt_input: None,
+            telemetry: None,
         };
 
         if let Some(rest) = input.strip_prefix("+++\n")
@@ -486,6 +499,7 @@ mod tests {
             expected_patches: Vec::new(),
             rejected_patch: None,
             captured_prompt_input: None,
+            telemetry: None,
         };
 
         // Cursor before `42`
@@ -620,6 +634,7 @@ mod tests {
             expected_patches: Vec::new(),
             rejected_patch: None,
             captured_prompt_input: None,
+            telemetry: None,
         };
 
         // Cursor before `42` using inline marker

crates/edit_prediction_cli/src/main.rs 🔗

@@ -105,8 +105,11 @@ Inputs can be file paths or special specifiers:
 
   captured-after:{timestamp}
       Fetch captured examples from Snowflake after the given RFC3339 timestamp.
+      These are examples captured via the "Capture Edit Prediction Example" action.
 
-      You can specify this multiple times and mix it with file inputs.
+  rejected-after:{timestamp}
+      Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
+      These are predictions that were shown to users but rejected (useful for DPO training).
 
       Required environment variables to connect to Snowflake:
           EP_SNOWFLAKE_API_KEY
@@ -117,20 +120,23 @@ Inputs can be file paths or special specifiers:
 
 Examples:
 
-  # Predict from a file
-  ep predict examples.jsonl
+  # Read examples from a file
+  ep read examples.jsonl -o output.jsonl
 
-  # Predict from captured examples after a timestamp
-  ep predict captured-after:2025-01-01T00:00:00Z
+  # Read captured examples after a timestamp
+  ep read captured-after:2025-01-01T00:00:00Z -o captured.jsonl
 
-  # Mix file inputs and captured-after in the same invocation
+  # Read rejected predictions for DPO training
+  ep read rejected-after:2025-01-01T00:00:00Z -o rejected.jsonl
+
+  # Mix multiple input sources
   ep predict examples.jsonl captured-after:2025-01-01T00:00:00Z
 "#;
 
 #[derive(Subcommand, Debug, Clone)]
 enum Command {
-    /// Parse markdown examples and output a combined .jsonl file
-    ParseExample,
+    /// Read examples from files or fetch from Snowflake, output as .jsonl
+    Read,
     /// Create git worktrees for each example and load file contents
     LoadProject,
     /// Retrieve context for input examples.
@@ -168,7 +174,7 @@ enum Command {
 impl Display for Command {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Command::ParseExample => write!(f, "parse-example"),
+            Command::Read => write!(f, "read"),
             Command::LoadProject => write!(f, "load-project"),
             Command::Context => write!(f, "context"),
             Command::FormatPrompt(args) => {
@@ -357,12 +363,17 @@ async fn load_examples(
     background_executor: BackgroundExecutor,
 ) -> anyhow::Result<Vec<Example>> {
     let mut captured_after_timestamps = Vec::new();
+    let mut rejected_after_timestamps = Vec::new();
     let mut file_inputs = Vec::new();
 
     for input in &args.inputs {
         let input_string = input.to_string_lossy();
         if let Some(timestamp) = pull_examples::parse_captured_after_input(input_string.as_ref()) {
             captured_after_timestamps.push(timestamp.to_string());
+        } else if let Some(timestamp) =
+            pull_examples::parse_rejected_after_input(input_string.as_ref())
+        {
+            rejected_after_timestamps.push(timestamp.to_string());
         } else {
             file_inputs.push(input.clone());
         }
@@ -377,21 +388,36 @@ async fn load_examples(
 
     if let Some(0) = remaining_limit_for_snowflake {
         log::info!(
-            "skipping captured-after inputs because --limit is already satisfied by example files"
+            "skipping Snowflake inputs because --limit is already satisfied by example files"
         );
-    } else if !captured_after_timestamps.is_empty() {
-        captured_after_timestamps.sort();
-
+    } else {
         let max_rows_per_timestamp = remaining_limit_for_snowflake.unwrap_or(5000);
 
-        let mut captured_examples = pull_examples::fetch_captured_examples_after(
-            http_client,
-            &captured_after_timestamps,
-            max_rows_per_timestamp,
-            background_executor,
-        )
-        .await?;
-        examples.append(&mut captured_examples);
+        if !captured_after_timestamps.is_empty() {
+            captured_after_timestamps.sort();
+
+            let mut captured_examples = pull_examples::fetch_captured_examples_after(
+                http_client.clone(),
+                &captured_after_timestamps,
+                max_rows_per_timestamp,
+                background_executor.clone(),
+            )
+            .await?;
+            examples.append(&mut captured_examples);
+        }
+
+        if !rejected_after_timestamps.is_empty() {
+            rejected_after_timestamps.sort();
+
+            let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
+                http_client,
+                &rejected_after_timestamps,
+                max_rows_per_timestamp,
+                background_executor,
+            )
+            .await?;
+            examples.append(&mut rejected_examples);
+        }
     }
 
     crate::example::sort_examples_by_repo_and_rev(&mut examples);
@@ -687,7 +713,7 @@ fn main() {
 
                                 let result = async {
                                     match &command {
-                                        Command::ParseExample => {}
+                                        Command::Read => {}
                                         Command::LoadProject => {
                                             run_load_project(
                                                 example,

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -9,15 +9,21 @@ use std::io::Read;
 use std::sync::Arc;
 use std::time::Duration;
 
-use crate::{
-    example::Example,
-    progress::{InfoStyle, Progress, Step},
+use zeta_prompt::ZetaPromptInput;
+
+use crate::example::Example;
+use crate::progress::{InfoStyle, Progress, Step};
+use edit_prediction::example_spec::{
+    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
+    TelemetrySource,
 };
-use edit_prediction::example_spec::ExampleSpec;
+use std::fmt::Write as _;
 
 const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
 const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
 const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
+const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
+const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
 
 const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
 const POLL_INTERVAL: Duration = Duration::from_secs(2);
@@ -28,6 +34,11 @@ pub fn parse_captured_after_input(input: &str) -> Option<&str> {
     input.strip_prefix("captured-after:")
 }
 
+/// Parse an input token of the form `rejected-after:{timestamp}`.
+pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
+    input.strip_prefix("rejected-after:")
+}
+
 pub async fn fetch_captured_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
@@ -302,6 +313,7 @@ async fn fetch_partition(
         )
         .header("Accept", "application/json")
         .header("Accept-Encoding", "gzip")
+        .header("User-Agent", "edit_prediction_cli")
         .body(AsyncBody::empty())?;
 
     let response = http_client
@@ -387,6 +399,7 @@ async fn run_sql(
         )
         .header("Content-Type", "application/json")
         .header("Accept", "application/json")
+        .header("User-Agent", "edit_prediction_cli")
         .body(AsyncBody::from(request_body.clone()))?;
 
     let response = http_client
@@ -414,3 +427,401 @@ async fn run_sql(
     serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
         .context("failed to parse Snowflake SQL API response JSON")
 }
+
+pub async fn fetch_rejected_examples_after(
+    http_client: Arc<dyn HttpClient>,
+    after_timestamps: &[String],
+    max_rows_per_timestamp: usize,
+    background_executor: BackgroundExecutor,
+) -> Result<Vec<Example>> {
+    if after_timestamps.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let progress = Progress::global();
+
+    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+    )?;
+    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+    let mut all_examples = Vec::new();
+
+    for after_date in after_timestamps.iter() {
+        let step_progress_name = format!("rejected>{after_date}");
+        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+        step_progress.set_substatus("querying");
+
+        // Join rejected events with their corresponding request events to get the full context.
+        // We filter for V3 sampling data which contains the structured input we need.
+        // We also filter for predictions that were actually shown to the user (was_shown = true)
+        // to focus on explicit user rejections rather than implicit cancellations.
+        let statement = indoc! {r#"
+            SELECT
+                req.event_properties:request_id::string AS request_id,
+                req.device_id::string AS device_id,
+                req.time::string AS time,
+                req.event_properties:input AS input,
+                req.event_properties:prompt::string AS prompt,
+                req.event_properties:output::string AS output,
+                rej.event_properties:was_shown::boolean AS was_shown,
+                rej.event_properties:reason::string AS reason
+            FROM events req
+            INNER JOIN events rej
+                ON req.event_properties:request_id = rej.event_properties:request_id
+            WHERE req.event_type = ?
+                AND rej.event_type = ?
+                AND req.event_properties:version = 'V3'
+                AND rej.event_properties:was_shown = true
+                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+            ORDER BY req.time ASC
+            LIMIT ?
+        "#};
+
+        let request = json!({
+            "statement": statement,
+            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            "database": "EVENTS",
+            "schema": "PUBLIC",
+            "warehouse": "DBT",
+            "role": role,
+            "bindings": {
+                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
+                "3": { "type": "TEXT", "value": after_date },
+                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+            }
+        });
+
+        let response = run_sql_with_polling(
+            http_client.clone(),
+            &base_url,
+            &token,
+            &request,
+            &step_progress,
+            background_executor.clone(),
+        )
+        .await?;
+
+        let total_rows = response
+            .result_set_meta_data
+            .as_ref()
+            .and_then(|m| m.num_rows)
+            .unwrap_or(response.data.len() as i64);
+
+        let num_partitions = response
+            .result_set_meta_data
+            .as_ref()
+            .map(|m| m.partition_info.len())
+            .unwrap_or(1)
+            .max(1);
+
+        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+        step_progress.set_substatus("parsing");
+
+        all_examples.extend(rejected_examples_from_response(&response)?);
+
+        if num_partitions > 1 {
+            let statement_handle = response
+                .statement_handle
+                .as_ref()
+                .context("response has multiple partitions but no statementHandle")?;
+
+            for partition in 1..num_partitions {
+                step_progress.set_substatus(format!(
+                    "fetching partition {}/{}",
+                    partition + 1,
+                    num_partitions
+                ));
+
+                let partition_response = fetch_partition(
+                    http_client.clone(),
+                    &base_url,
+                    &token,
+                    statement_handle,
+                    partition,
+                )
+                .await?;
+
+                all_examples.extend(rejected_examples_from_response(&partition_response)?);
+            }
+        }
+
+        step_progress.set_substatus("done");
+    }
+
+    Ok(all_examples)
+}
+
+fn rejected_examples_from_response(
+    response: &SnowflakeStatementResponse,
+) -> Result<impl Iterator<Item = Example> + '_> {
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake sql api returned error code={code} message={}",
+                response.message.as_deref().unwrap_or("<no message>")
+            );
+        }
+    }
+
+    let column_indices = get_column_indices(
+        &response.result_set_meta_data,
+        &[
+            "request_id",
+            "device_id",
+            "time",
+            "input",
+            "prompt",
+            "output",
+            "was_shown",
+            "reason",
+        ],
+    );
+
+    let iter = response
+        .data
+        .iter()
+        .enumerate()
+        .filter_map(move |(row_index, data_row)| {
+            let get_string = |name: &str| -> Option<String> {
+                let index = column_indices.get(name).copied()?;
+                match data_row.get(index)? {
+                    JsonValue::String(s) => Some(s.clone()),
+                    JsonValue::Null => None,
+                    other => Some(other.to_string()),
+                }
+            };
+
+            let get_json = |name: &str| -> Option<JsonValue> {
+                let index = column_indices.get(name).copied()?;
+                let value = data_row.get(index)?;
+                if value.is_null() {
+                    return None;
+                }
+                match value {
+                    JsonValue::String(s) => serde_json::from_str(s).ok(),
+                    other => Some(other.clone()),
+                }
+            };
+
+            let get_bool = |name: &str| -> Option<bool> {
+                let index = column_indices.get(name).copied()?;
+                match data_row.get(index)? {
+                    JsonValue::Bool(b) => Some(*b),
+                    JsonValue::String(s) => s.parse().ok(),
+                    _ => None,
+                }
+            };
+
+            let request_id_str = get_string("request_id");
+            let device_id = get_string("device_id");
+            let time = get_string("time");
+            let input_json = get_json("input");
+            let input: Option<ZetaPromptInput> =
+                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
+            let output = get_string("output");
+            let was_shown = get_bool("was_shown");
+            let reason = get_string("reason");
+
+            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
+                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
+                    Some(build_rejected_example(
+                        request_id,
+                        device_id,
+                        time,
+                        input,
+                        output,
+                        was_shown,
+                        reason,
+                    ))
+                }
+                _ => {
+                    log::warn!(
+                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
+                        request_id_str.is_some(),
+                        device_id.is_some(),
+                        time.is_some(),
+                        input_json.is_some(),
+                        output.is_some(),
+                        was_shown.is_some(),
+                        reason.is_some()
+                    );
+                    None
+                }
+            }
+        });
+
+    Ok(iter)
+}
+
+fn build_rejected_example(
+    request_id: String,
+    device_id: String,
+    time: String,
+    input: ZetaPromptInput,
+    output: String,
+    was_shown: bool,
+    reason: String,
+) -> Example {
+    let events: Vec<CapturedEvent> = input
+        .events
+        .iter()
+        .map(|event| match event.as_ref() {
+            zeta_prompt::Event::BufferChange {
+                path,
+                old_path,
+                diff,
+                predicted,
+                in_open_source_repo,
+            } => CapturedEvent {
+                path: path.clone(),
+                old_path: old_path.clone(),
+                diff: diff.clone(),
+                predicted: *predicted,
+                in_open_source_repo: *in_open_source_repo,
+            },
+        })
+        .collect();
+
+    let related_files: Vec<CapturedRelatedFile> = input
+        .related_files
+        .iter()
+        .map(|rf| CapturedRelatedFile {
+            path: rf.path.clone(),
+            max_row: rf.max_row,
+            excerpts: rf
+                .excerpts
+                .iter()
+                .map(|e| CapturedRelatedExcerpt {
+                    row_range: e.row_range.clone(),
+                    text: e.text.to_string(),
+                })
+                .collect(),
+        })
+        .collect();
+
+    let cursor_excerpt = input.cursor_excerpt.as_ref();
+    let cursor_offset = input.cursor_offset_in_excerpt;
+
+    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
+
+    let mut edit_history = String::new();
+    for event in &input.events {
+        zeta_prompt::write_event(&mut edit_history, event);
+        edit_history.push('\n');
+    }
+
+    let rejected_patch = build_rejected_patch(
+        &input.cursor_path,
+        cursor_excerpt,
+        &input.editable_range_in_excerpt,
+        &output,
+    );
+
+    let spec = ExampleSpec {
+        name: request_id.clone(),
+        repository_url: String::new(),
+        revision: String::new(),
+        tags: vec![format!("rejection:{}", reason.to_lowercase())],
+        reasoning: None,
+        uncommitted_diff: String::new(),
+        cursor_path: input.cursor_path.clone(),
+        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
+        edit_history,
+        expected_patches: Vec::new(),
+        rejected_patch: Some(rejected_patch),
+        captured_prompt_input: Some(CapturedPromptInput {
+            cursor_file_content: cursor_excerpt.to_string(),
+            cursor_offset,
+            cursor_row,
+            cursor_column,
+            events,
+            related_files,
+        }),
+        telemetry: Some(TelemetrySource {
+            request_id,
+            device_id,
+            time,
+            rejection_reason: reason,
+            was_shown,
+        }),
+    };
+
+    Example {
+        spec,
+        prompt_inputs: None,
+        prompt: None,
+        predictions: Vec::new(),
+        score: Vec::new(),
+        state: None,
+    }
+}
+
+fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
+    let mut row = 0u32;
+    let mut last_newline_offset = 0;
+    for (i, c) in text.char_indices() {
+        if i >= offset {
+            break;
+        }
+        if c == '\n' {
+            row += 1;
+            last_newline_offset = i + 1;
+        }
+    }
+    let column = (offset - last_newline_offset) as u32;
+    (row, column)
+}
+
+fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
+    let before = &excerpt[..cursor_offset.min(excerpt.len())];
+    let after = &excerpt[cursor_offset.min(excerpt.len())..];
+    format!("{}[CURSOR_POSITION]{}", before, after)
+}
+
+fn build_rejected_patch(
+    cursor_path: &std::path::Path,
+    cursor_excerpt: &str,
+    editable_range: &std::ops::Range<usize>,
+    model_output: &str,
+) -> String {
+    let old_text = &cursor_excerpt[editable_range.clone()];
+
+    let editable_start_row = cursor_excerpt[..editable_range.start]
+        .chars()
+        .filter(|&c| c == '\n')
+        .count() as u32;
+
+    let diff_body = language::unified_diff_with_offsets(
+        old_text,
+        model_output,
+        editable_start_row,
+        editable_start_row,
+    );
+
+    let mut patch = String::new();
+    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
+    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
+    patch.push_str(&diff_body);
+    patch
+}
+
+fn get_column_indices(
+    meta: &Option<SnowflakeResultSetMetaData>,
+    names: &[&str],
+) -> std::collections::HashMap<String, usize> {
+    let mut indices = std::collections::HashMap::new();
+    if let Some(meta) = meta {
+        for (index, col) in meta.row_type.iter().enumerate() {
+            for &name in names {
+                if col.name.eq_ignore_ascii_case(name) {
+                    indices.insert(name.to_string(), index);
+                }
+            }
+        }
+    }
+    indices
+}

crates/edit_prediction_cli/src/split_commit.rs 🔗

@@ -363,7 +363,6 @@ pub fn generate_evaluation_example_from_ordered_commit(
         repository_url: repository_url.to_string(),
         revision: format!("{}~1", commit_hash),
         edit_history: split_commit.source_patch.clone(),
-        // cursor_position: cursor.to_string(),
         cursor_path: Path::new(&cursor.file).into(),
         cursor_position: cursor_excerpt,
         expected_patches: vec![split_commit.target_patch],
@@ -372,6 +371,7 @@ pub fn generate_evaluation_example_from_ordered_commit(
         uncommitted_diff: String::new(),
         rejected_patch: None,
         captured_prompt_input: None,
+        telemetry: None,
     })
 }
 
@@ -1395,7 +1395,6 @@ Date: Mon Jan 1 00:00:00 2024
             repository_url: "https://github.com/test/repo".to_string(),
             revision: "abc123~1".to_string(),
             edit_history: "patch1".to_string(),
-            // cursor_position: "file.rs:10:5".to_string(),
             cursor_path: Path::new("file.rs").into(),
             cursor_position: "some code<|user_cursor|>".to_string(),
             expected_patches: vec!["patch".to_string()],
@@ -1404,6 +1403,7 @@ Date: Mon Jan 1 00:00:00 2024
             uncommitted_diff: String::new(),
             rejected_patch: None,
             captured_prompt_input: None,
+            telemetry: None,
         };
 
         let json = serde_json::to_string(&case).unwrap();

crates/edit_prediction_cli/src/synthesize.rs 🔗

@@ -793,6 +793,7 @@ async fn build_example(
         expected_patches: vec![expected_patch_with_header],
         rejected_patch: None,
         captured_prompt_input: None,
+        telemetry: None,
     };
     spec.set_cursor_excerpt(&excerpt, cursor_offset, comment_prefix);