diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 1ab126d32ee19b2eb754f4ad31fbaf38ed5eaafc..207a69328fb07277c39463c0c6a460862c95fe42 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -55,6 +55,7 @@ use crate::load_project::run_load_project; use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR}; use crate::predict::run_prediction; use crate::progress::Progress; +use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input}; use crate::retrieve_context::run_context_retrieval; use crate::score::run_scoring; use crate::split_commit::SplitCommitArgs; @@ -132,6 +133,10 @@ Inputs can be file paths or special specifiers: Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp. These are predictions that were shown to users but rejected (useful for DPO training). + settled-after:{timestamp} + Fetch settled stream examples from Snowflake after the given RFC3339 timestamp. + These are examples from the edit prediction settled stream. + rated-after:{timestamp} Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp. These are predictions that users explicitly rated as positive or negative via the @@ -166,6 +171,9 @@ Examples: # Read user-rated predictions ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl + # Read settled stream examples + ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl + # Read only positively rated predictions ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl @@ -635,6 +643,7 @@ async fn load_examples( let mut captured_after_timestamps = Vec::new(); let mut rejected_after_timestamps = Vec::new(); let mut requested_after_timestamps = Vec::new(); + let mut settled_after_timestamps = Vec::new(); let mut rated_after_inputs: Vec<(String, Option)> = Vec::new(); let mut file_inputs = Vec::new(); @@ -651,6 +660,8 @@ async fn load_examples( pull_examples::parse_requested_after_input(input_string.as_ref()) { requested_after_timestamps.push(timestamp.to_string()); + } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) { + settled_after_timestamps.push(timestamp.to_string()); } else if let Some((timestamp, rating_filter)) = pull_examples::parse_rated_after_input(input_string.as_ref()) { @@ -718,6 +729,21 @@ async fn load_examples( examples.append(&mut requested_examples); } + if !settled_after_timestamps.is_empty() { + settled_after_timestamps.sort(); + + let mut settled_examples = fetch_settled_examples_after( + http_client.clone(), + &settled_after_timestamps, + max_rows_per_timestamp, + remaining_offset, + background_executor.clone(), + Some(MIN_CAPTURE_VERSION), + ) + .await?; + examples.append(&mut settled_examples); + } + if !rated_after_inputs.is_empty() { rated_after_inputs.sort(); diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index cacfc9bb679acdcb3c709736c6e4b5e79af861e8..e34fc62c031cbfa411d9d5a701a3e327d0be8166 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -5,24 +5,25 @@ use http_client::{AsyncBody, HttpClient, Method, Request}; use indoc::indoc; use serde::Deserialize; use serde_json::{Value as JsonValue, json}; +use std::fmt::Write as _; use std::io::Read; use std::sync::Arc; use std::time::Duration; use telemetry_events::EditPredictionRating; -use zeta_prompt::ZetaPromptInput; +use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format}; use crate::example::Example; use crate::progress::{InfoStyle, Progress, Step}; const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment"; use edit_prediction::example_spec::{ExampleSpec, TelemetrySource}; -use std::fmt::Write as _; pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001"; pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334"; const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested"; const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected"; const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated"; +const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled"; /// Minimum Zed version for filtering captured examples. /// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples @@ -34,6 +35,7 @@ pub struct MinCaptureVersion { } const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120; +const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240; pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2); pub(crate) const MAX_POLL_ATTEMPTS: usize = 120; @@ -52,6 +54,11 @@ pub fn parse_requested_after_input(input: &str) -> Option<&str> { input.strip_prefix("requested-after:") } +/// Parse an input token of the form `settled-after:{timestamp}`. +pub fn parse_settled_after_input(input: &str) -> Option<&str> { + input.strip_prefix("settled-after:") +} + /// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`, /// or `rated-negative-after:{timestamp}`. /// Returns `(timestamp, Option)` where `None` means all ratings. @@ -596,6 +603,163 @@ pub async fn fetch_requested_examples_after( Ok(all_examples) } +pub async fn fetch_settled_examples_after( + http_client: Arc, + after_timestamps: &[String], + max_rows_per_timestamp: usize, + offset: usize, + background_executor: BackgroundExecutor, + min_capture_version: Option, +) -> Result> { + if after_timestamps.is_empty() { + return Ok(Vec::new()); + } + + let progress = Progress::global(); + + let token = std::env::var("EP_SNOWFLAKE_API_KEY") + .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; + let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( + "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", + )?; + let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); + + let mut all_examples = Vec::new(); + + for after_date in after_timestamps.iter() { + let step_progress_name = format!("settled>{after_date}"); + let step_progress = progress.start(Step::PullExamples, &step_progress_name); + step_progress.set_substatus("querying"); + + let statement = indoc! {r#" + WITH requested AS ( + SELECT + req.event_properties:request_id::string AS request_id, + req.device_id::string AS device_id, + req.time AS req_time, + req.time::string AS time, + req.event_properties:input AS input, + req.event_properties:format::string AS requested_format, + req.event_properties:output::string AS requested_output, + req.event_properties:zed_version::string AS zed_version + FROM events req + WHERE req.event_type = ? + AND req.event_properties:version = 'V3' + AND req.event_properties:input:can_collect_data = true + AND req.time > TRY_TO_TIMESTAMP_NTZ(?) + ) + SELECT + req.request_id AS request_id, + req.device_id AS device_id, + req.time AS time, + req.input AS input, + req.requested_output AS requested_output, + settled.event_properties:settled_editable_region::string AS settled_editable_region, + req.requested_format AS requested_format, + req.zed_version AS zed_version + FROM requested req + INNER JOIN events settled + ON req.request_id = settled.event_properties:request_id::string + WHERE settled.event_type = ? + ORDER BY req.req_time ASC + LIMIT ? + OFFSET ? + "#}; + + let _ = min_capture_version; + let request = json!({ + "statement": statement, + "timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS, + "database": "EVENTS", + "schema": "PUBLIC", + "warehouse": "DBT", + "role": role, + "bindings": { + "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, + "2": { "type": "TEXT", "value": after_date }, + "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT }, + "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, + "5": { "type": "FIXED", "value": offset.to_string() } + } + }); + + let response = run_sql_with_polling( + http_client.clone(), + &base_url, + &token, + &request, + &step_progress, + background_executor.clone(), + ) + .await?; + + let total_rows = response + .result_set_meta_data + .as_ref() + .and_then(|m| m.num_rows) + .unwrap_or(response.data.len() as i64); + + let num_partitions = response + .result_set_meta_data + .as_ref() + .map(|m| m.partition_info.len()) + .unwrap_or(1) + .max(1); + + step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); + step_progress.set_substatus("parsing"); + + let column_indices = get_column_indices( + &response.result_set_meta_data, + &[ + "request_id", + "device_id", + "time", + "input", + "requested_output", + "settled_editable_region", + "requested_format", + "zed_version", + ], + ); + + all_examples.extend(settled_examples_from_response(&response, &column_indices)?); + + if num_partitions > 1 { + let statement_handle = response + .statement_handle + .as_ref() + .context("response has multiple partitions but no statementHandle")?; + + for partition in 1..num_partitions { + step_progress.set_substatus(format!( + "fetching partition {}/{}", + partition + 1, + num_partitions + )); + + let partition_response = fetch_partition( + http_client.clone(), + &base_url, + &token, + statement_handle, + partition, + ) + .await?; + + all_examples.extend(settled_examples_from_response( + &partition_response, + &column_indices, + )?); + } + } + + step_progress.set_substatus("done"); + } + + Ok(all_examples) +} + pub async fn fetch_rated_examples_after( http_client: Arc, inputs: &[(String, Option)], @@ -989,6 +1153,186 @@ fn requested_examples_from_response<'a>( Ok(iter) } +fn settled_examples_from_response<'a>( + response: &'a SnowflakeStatementResponse, + column_indices: &'a std::collections::HashMap, +) -> Result + 'a> { + if let Some(code) = &response.code { + if code != SNOWFLAKE_SUCCESS_CODE { + anyhow::bail!( + "snowflake sql api returned error code={code} message={}", + response.message.as_deref().unwrap_or("") + ); + } + } + + let iter = response + .data + .iter() + .enumerate() + .filter_map(move |(row_index, data_row)| { + let get_value = |name: &str| -> Option { + let index = column_indices.get(name).copied()?; + let value = data_row.get(index)?; + if value.is_null() { + None + } else { + Some(value.clone()) + } + }; + + let get_string = |name: &str| -> Option { + match get_value(name)? { + JsonValue::String(s) => Some(s), + other => Some(other.to_string()), + } + }; + + let parse_json_value = |_: &str, raw: Option<&JsonValue>| -> Option { + let value = raw?; + match value { + JsonValue::String(s) => serde_json::from_str::(s).ok(), + other => Some(other.clone()), + } + }; + + let request_id_str = get_string("request_id"); + let device_id = get_string("device_id"); + let time = get_string("time"); + let input_raw = get_value("input"); + let input_json = parse_json_value("input", input_raw.as_ref()); + let input: Option = input_json + .as_ref() + .and_then(|parsed| serde_json::from_value(parsed.clone()).ok()); + let requested_output = get_string("requested_output"); + let settled_editable_region = get_string("settled_editable_region"); + let requested_format = + get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok()); + let zed_version = get_string("zed_version"); + + match ( + request_id_str.clone(), + device_id.clone(), + time.clone(), + input.clone(), + requested_output.clone(), + settled_editable_region.clone(), + requested_format, + ) { + ( + Some(request_id), + Some(device_id), + Some(time), + Some(input), + Some(requested_output), + Some(settled_editable_region), + Some(requested_format), + ) => Some(build_settled_example( + request_id, + device_id, + time, + input, + requested_output, + settled_editable_region, + requested_format, + zed_version, + )), + _ => { + let mut missing_fields = Vec::new(); + + if request_id_str.is_none() { + missing_fields.push("request_id"); + } + if device_id.is_none() { + missing_fields.push("device_id"); + } + if time.is_none() { + missing_fields.push("time"); + } + if input_raw.is_none() || input_json.is_none() || input.is_none() { + missing_fields.push("input"); + } + if requested_output.is_none() { + missing_fields.push("requested_output"); + } + if settled_editable_region.is_none() { + missing_fields.push("settled_editable_region"); + } + if requested_format.is_none() { + missing_fields.push("requested_format"); + } + + log::warn!( + "skipping settled row {row_index}: [{}]", + missing_fields.join(", "), + ); + None + } + } + }); + + Ok(iter) +} + +fn build_settled_example( + request_id: String, + device_id: String, + time: String, + input: ZetaPromptInput, + requested_output: String, + settled_editable_region: String, + requested_format: ZetaFormat, + zed_version: Option, +) -> Example { + let requested_editable_range = input + .excerpt_ranges + .as_ref() + .map(|ranges| excerpt_range_for_format(requested_format, ranges).0) + .unwrap_or_else(|| input.editable_range_in_excerpt.clone()); + + let base_cursor_excerpt = input.cursor_excerpt.to_string(); + + let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end + && requested_editable_range.end <= base_cursor_excerpt.len(); + let mut example = build_example_from_snowflake( + request_id.clone(), + device_id, + time, + input, + vec!["settled".to_string()], + None, + zed_version, + ); + + if !requested_range_is_valid { + log::warn!( + "skipping malformed requested range for request {}: requested={:?} (base_len={})", + request_id, + requested_editable_range, + base_cursor_excerpt.len(), + ); + return example; + } + + let settled_replacement = settled_editable_region.as_str(); + let rejected_patch = build_output_patch( + &example.spec.cursor_path, + &base_cursor_excerpt, + &requested_editable_range, + &requested_output, + ); + let expected_patch = build_output_patch( + &example.spec.cursor_path, + &base_cursor_excerpt, + &requested_editable_range, + settled_replacement, + ); + + example.spec.expected_patches = vec![expected_patch]; + example.spec.rejected_patch = Some(rejected_patch); + example +} + fn rejected_examples_from_response<'a>( response: &'a SnowflakeStatementResponse, column_indices: &'a std::collections::HashMap,