ep: Add settled data fetching from snowflake (#50326)

Ben Kunkle created

Closes #ISSUE

Before you mark this PR as ready for review, make sure that you have:
- [ ] Added solid test coverage and/or screenshots from manual
testing
- [ ] Done a self-review taking into account security and performance
aspects
- [ ] Aligned any UI changes with the [UI
checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist)

Release Notes:

- N/A *or* Added/Fixed/Improved ...

Change summary

crates/edit_prediction_cli/src/main.rs          |  26 +
crates/edit_prediction_cli/src/pull_examples.rs | 348 ++++++++++++++++++
2 files changed, 372 insertions(+), 2 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/main.rs 🔗

@@ -55,6 +55,7 @@ use crate::load_project::run_load_project;
 use crate::paths::{FAILED_EXAMPLES_DIR, RUN_DIR};
 use crate::predict::run_prediction;
 use crate::progress::Progress;
+use crate::pull_examples::{fetch_settled_examples_after, parse_settled_after_input};
 use crate::retrieve_context::run_context_retrieval;
 use crate::score::run_scoring;
 use crate::split_commit::SplitCommitArgs;
@@ -132,6 +133,10 @@ Inputs can be file paths or special specifiers:
       Fetch rejected edit predictions from Snowflake after the given RFC3339 timestamp.
       These are predictions that were shown to users but rejected (useful for DPO training).
 
+  settled-after:{timestamp}
+      Fetch settled stream examples from Snowflake after the given RFC3339 timestamp.
+      These are examples from the edit prediction settled stream.
+
   rated-after:{timestamp}
       Fetch user-rated edit predictions from Snowflake after the given RFC3339 timestamp.
       These are predictions that users explicitly rated as positive or negative via the
@@ -166,6 +171,9 @@ Examples:
   # Read user-rated predictions
   ep read rated-after:2025-01-01T00:00:00Z -o rated.jsonl
 
+  # Read settled stream examples
+  ep read settled-after:2025-01-01T00:00:00Z -o settled.jsonl
+
   # Read only positively rated predictions
   ep read rated-positive-after:2025-01-01T00:00:00Z -o positive.jsonl
 
@@ -635,6 +643,7 @@ async fn load_examples(
     let mut captured_after_timestamps = Vec::new();
     let mut rejected_after_timestamps = Vec::new();
     let mut requested_after_timestamps = Vec::new();
+    let mut settled_after_timestamps = Vec::new();
     let mut rated_after_inputs: Vec<(String, Option<telemetry_events::EditPredictionRating>)> =
         Vec::new();
     let mut file_inputs = Vec::new();
@@ -651,6 +660,8 @@ async fn load_examples(
             pull_examples::parse_requested_after_input(input_string.as_ref())
         {
             requested_after_timestamps.push(timestamp.to_string());
+        } else if let Some(timestamp) = parse_settled_after_input(input_string.as_ref()) {
+            settled_after_timestamps.push(timestamp.to_string());
         } else if let Some((timestamp, rating_filter)) =
             pull_examples::parse_rated_after_input(input_string.as_ref())
         {
@@ -718,6 +729,21 @@ async fn load_examples(
             examples.append(&mut requested_examples);
         }
 
+        if !settled_after_timestamps.is_empty() {
+            settled_after_timestamps.sort();
+
+            let mut settled_examples = fetch_settled_examples_after(
+                http_client.clone(),
+                &settled_after_timestamps,
+                max_rows_per_timestamp,
+                remaining_offset,
+                background_executor.clone(),
+                Some(MIN_CAPTURE_VERSION),
+            )
+            .await?;
+            examples.append(&mut settled_examples);
+        }
+
         if !rated_after_inputs.is_empty() {
             rated_after_inputs.sort();
 

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -5,24 +5,25 @@ use http_client::{AsyncBody, HttpClient, Method, Request};
 use indoc::indoc;
 use serde::Deserialize;
 use serde_json::{Value as JsonValue, json};
+use std::fmt::Write as _;
 use std::io::Read;
 use std::sync::Arc;
 use std::time::Duration;
 use telemetry_events::EditPredictionRating;
 
-use zeta_prompt::ZetaPromptInput;
+use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
 
 use crate::example::Example;
 use crate::progress::{InfoStyle, Progress, Step};
 const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
 use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
-use std::fmt::Write as _;
 
 pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
 pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
 const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
 const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
 const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
+const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
 
 /// Minimum Zed version for filtering captured examples.
 /// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
@@ -34,6 +35,7 @@ pub struct MinCaptureVersion {
 }
 
 const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
 pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
 pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
 
@@ -52,6 +54,11 @@ pub fn parse_requested_after_input(input: &str) -> Option<&str> {
     input.strip_prefix("requested-after:")
 }
 
+/// Parse an input token of the form `settled-after:{timestamp}`.
+///
+/// Returns the text that follows the `settled-after:` prefix, or `None` when
+/// `input` does not start with that prefix. The timestamp text itself is not
+/// validated here.
+pub fn parse_settled_after_input(input: &str) -> Option<&str> {
+    input.strip_prefix("settled-after:")
+}
+
 /// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
 /// or `rated-negative-after:{timestamp}`.
 /// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
@@ -596,6 +603,163 @@ pub async fn fetch_requested_examples_after(
     Ok(all_examples)
 }
 
+/// Fetches "settled" edit prediction examples from Snowflake.
+///
+/// For each entry in `after_timestamps`, runs one statement that selects
+/// `Predictive Edit Requested` events newer than that timestamp (restricted to
+/// `version = 'V3'` and `can_collect_data = true`) and inner-joins them with
+/// `Edit Prediction Settled` events on `request_id`. Rows are converted with
+/// `settled_examples_from_response` and accumulated across all timestamps.
+///
+/// `max_rows_per_timestamp` and `offset` are bound to the statement's `LIMIT` and
+/// `OFFSET`, so both apply to every timestamp's query individually.
+/// `min_capture_version` is accepted but currently ignored.
+///
+/// Returns an empty list, without reading the environment, when
+/// `after_timestamps` is empty. Otherwise reads `EP_SNOWFLAKE_API_KEY` and
+/// `EP_SNOWFLAKE_BASE_URL` (required) and `EP_SNOWFLAKE_ROLE` (optional).
+///
+/// # Errors
+///
+/// Returns an error if a required environment variable is missing, if running the
+/// statement or fetching a partition fails, if a response cannot be converted into
+/// examples, or if a multi-partition response has no statement handle.
+pub async fn fetch_settled_examples_after(
+    http_client: Arc<dyn HttpClient>,
+    after_timestamps: &[String],
+    max_rows_per_timestamp: usize,
+    offset: usize,
+    background_executor: BackgroundExecutor,
+    min_capture_version: Option<MinCaptureVersion>,
+) -> Result<Vec<Example>> {
+    if after_timestamps.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let progress = Progress::global();
+
+    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+    )?;
+    // The role is optional; when unset, `null` is sent in the request's `role` field.
+    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+    let mut all_examples = Vec::new();
+
+    for after_date in after_timestamps.iter() {
+        let step_progress_name = format!("settled>{after_date}");
+        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+        step_progress.set_substatus("querying");
+
+        // The `?` placeholders are filled, in order, by the numbered entries of
+        // `bindings` in the request below.
+        let statement = indoc! {r#"
+            WITH requested AS (
+                SELECT
+                    req.event_properties:request_id::string AS request_id,
+                    req.device_id::string AS device_id,
+                    req.time AS req_time,
+                    req.time::string AS time,
+                    req.event_properties:input AS input,
+                    req.event_properties:format::string AS requested_format,
+                    req.event_properties:output::string AS requested_output,
+                    req.event_properties:zed_version::string AS zed_version
+                FROM events req
+                WHERE req.event_type = ?
+                    AND req.event_properties:version = 'V3'
+                    AND req.event_properties:input:can_collect_data = true
+                    AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+            )
+            SELECT
+                req.request_id AS request_id,
+                req.device_id AS device_id,
+                req.time AS time,
+                req.input AS input,
+                req.requested_output AS requested_output,
+                settled.event_properties:settled_editable_region::string AS settled_editable_region,
+                req.requested_format AS requested_format,
+                req.zed_version AS zed_version
+            FROM requested req
+            INNER JOIN events settled
+                ON req.request_id = settled.event_properties:request_id::string
+            WHERE settled.event_type = ?
+            ORDER BY req.req_time ASC
+            LIMIT ?
+            OFFSET ?
+        "#};
+
+        // NOTE(review): `min_capture_version` is discarded here and the statement above
+        // has no version filter, so a caller-supplied minimum version has no effect —
+        // confirm this is intentional.
+        let _ = min_capture_version;
+        let request = json!({
+            "statement": statement,
+            "timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS,
+            "database": "EVENTS",
+            "schema": "PUBLIC",
+            "warehouse": "DBT",
+            "role": role,
+            "bindings": {
+                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+                "2": { "type": "TEXT", "value": after_date },
+                "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
+                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
+                "5": { "type": "FIXED", "value": offset.to_string() }
+            }
+        });
+
+        let response = run_sql_with_polling(
+            http_client.clone(),
+            &base_url,
+            &token,
+            &request,
+            &step_progress,
+            background_executor.clone(),
+        )
+        .await?;
+
+        // Prefer the row count reported in the metadata; fall back to the number of
+        // rows carried by this response.
+        let total_rows = response
+            .result_set_meta_data
+            .as_ref()
+            .and_then(|m| m.num_rows)
+            .unwrap_or(response.data.len() as i64);
+
+        // Missing metadata or an empty partition list is treated as one partition.
+        let num_partitions = response
+            .result_set_meta_data
+            .as_ref()
+            .map(|m| m.partition_info.len())
+            .unwrap_or(1)
+            .max(1);
+
+        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+        step_progress.set_substatus("parsing");
+
+        // Resolved once from the initial response and reused for every partition.
+        let column_indices = get_column_indices(
+            &response.result_set_meta_data,
+            &[
+                "request_id",
+                "device_id",
+                "time",
+                "input",
+                "requested_output",
+                "settled_editable_region",
+                "requested_format",
+                "zed_version",
+            ],
+        );
+
+        // Rows in the initial response are parsed first; the loop below therefore
+        // starts at partition index 1.
+        all_examples.extend(settled_examples_from_response(&response, &column_indices)?);
+
+        if num_partitions > 1 {
+            let statement_handle = response
+                .statement_handle
+                .as_ref()
+                .context("response has multiple partitions but no statementHandle")?;
+
+            for partition in 1..num_partitions {
+                // `partition` is zero-based; the progress text shows it one-based.
+                step_progress.set_substatus(format!(
+                    "fetching partition {}/{}",
+                    partition + 1,
+                    num_partitions
+                ));
+
+                let partition_response = fetch_partition(
+                    http_client.clone(),
+                    &base_url,
+                    &token,
+                    statement_handle,
+                    partition,
+                )
+                .await?;
+
+                all_examples.extend(settled_examples_from_response(
+                    &partition_response,
+                    &column_indices,
+                )?);
+            }
+        }
+
+        step_progress.set_substatus("done");
+    }
+
+    Ok(all_examples)
+}
+
 pub async fn fetch_rated_examples_after(
     http_client: Arc<dyn HttpClient>,
     inputs: &[(String, Option<EditPredictionRating>)],
@@ -989,6 +1153,186 @@ fn requested_examples_from_response<'a>(
     Ok(iter)
 }
 
+/// Converts the rows of a Snowflake statement response into settled [`Example`]s.
+///
+/// `column_indices` maps column names to positions within each data row. A row is
+/// turned into an example only when `request_id`, `device_id`, `time`, `input`,
+/// `requested_output`, `settled_editable_region`, and `requested_format` are all
+/// present and parseable; `zed_version` is optional. Any other row is skipped with
+/// a warning that lists the offending fields.
+///
+/// # Errors
+///
+/// Returns an error when the response carries a code other than
+/// `SNOWFLAKE_SUCCESS_CODE`. A response without a code is accepted.
+fn settled_examples_from_response<'a>(
+    response: &'a SnowflakeStatementResponse,
+    column_indices: &'a std::collections::HashMap<String, usize>,
+) -> Result<impl Iterator<Item = Example> + 'a> {
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake sql api returned error code={code} message={}",
+                response.message.as_deref().unwrap_or("<no message>")
+            );
+        }
+    }
+
+    // `row_index` is relative to this response's `data`, not to the overall result set.
+    let iter = response
+        .data
+        .iter()
+        .enumerate()
+        .filter_map(move |(row_index, data_row)| {
+            // Looks up a cell by column name; an unknown column, a row that is too
+            // short, and a JSON null all yield `None`.
+            let get_value = |name: &str| -> Option<JsonValue> {
+                let index = column_indices.get(name).copied()?;
+                let value = data_row.get(index)?;
+                if value.is_null() {
+                    None
+                } else {
+                    Some(value.clone())
+                }
+            };
+
+            // JSON strings are returned as-is; any other JSON value is rendered with
+            // `to_string`.
+            let get_string = |name: &str| -> Option<String> {
+                match get_value(name)? {
+                    JsonValue::String(s) => Some(s),
+                    other => Some(other.to_string()),
+                }
+            };
+
+            // A string cell is parsed as JSON text; any other value is used directly.
+            // The first argument is unused.
+            let parse_json_value = |_: &str, raw: Option<&JsonValue>| -> Option<JsonValue> {
+                let value = raw?;
+                match value {
+                    JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
+                    other => Some(other.clone()),
+                }
+            };
+
+            let request_id_str = get_string("request_id");
+            let device_id = get_string("device_id");
+            let time = get_string("time");
+            // `input` is kept at each stage (raw cell, parsed JSON, typed value) so the
+            // fallback arm below can tell whether any stage failed.
+            let input_raw = get_value("input");
+            let input_json = parse_json_value("input", input_raw.as_ref());
+            let input: Option<ZetaPromptInput> = input_json
+                .as_ref()
+                .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
+            let requested_output = get_string("requested_output");
+            let settled_editable_region = get_string("settled_editable_region");
+            // A format string that `ZetaFormat::parse` rejects is treated as missing.
+            let requested_format =
+                get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
+            let zed_version = get_string("zed_version");
+
+            // The tuple holds clones so the originals stay available for the
+            // diagnostics in the fallback arm.
+            match (
+                request_id_str.clone(),
+                device_id.clone(),
+                time.clone(),
+                input.clone(),
+                requested_output.clone(),
+                settled_editable_region.clone(),
+                requested_format,
+            ) {
+                (
+                    Some(request_id),
+                    Some(device_id),
+                    Some(time),
+                    Some(input),
+                    Some(requested_output),
+                    Some(settled_editable_region),
+                    Some(requested_format),
+                ) => Some(build_settled_example(
+                    request_id,
+                    device_id,
+                    time,
+                    input,
+                    requested_output,
+                    settled_editable_region,
+                    requested_format,
+                    zed_version,
+                )),
+                _ => {
+                    // Collect every field that was absent or failed to parse so the
+                    // warning names all of them at once.
+                    let mut missing_fields = Vec::new();
+
+                    if request_id_str.is_none() {
+                        missing_fields.push("request_id");
+                    }
+                    if device_id.is_none() {
+                        missing_fields.push("device_id");
+                    }
+                    if time.is_none() {
+                        missing_fields.push("time");
+                    }
+                    // Reported when the cell is absent, is not valid JSON, or does not
+                    // deserialize into `ZetaPromptInput`.
+                    if input_raw.is_none() || input_json.is_none() || input.is_none() {
+                        missing_fields.push("input");
+                    }
+                    if requested_output.is_none() {
+                        missing_fields.push("requested_output");
+                    }
+                    if settled_editable_region.is_none() {
+                        missing_fields.push("settled_editable_region");
+                    }
+                    if requested_format.is_none() {
+                        missing_fields.push("requested_format");
+                    }
+
+                    log::warn!(
+                        "skipping settled row {row_index}: [{}]",
+                        missing_fields.join(", "),
+                    );
+                    None
+                }
+            }
+        });
+
+    Ok(iter)
+}
+
+/// Builds an [`Example`] for one settled prediction.
+///
+/// The example itself is created by `build_example_from_snowflake`. When the
+/// editable range chosen for `requested_format` fits within the cursor excerpt,
+/// two patches over that range are attached via `build_output_patch`:
+/// `expected_patches` is built from `settled_editable_region` and
+/// `rejected_patch` from `requested_output`. When the range is malformed, a
+/// warning is logged and the example is returned without those patches.
+fn build_settled_example(
+    request_id: String,
+    device_id: String,
+    time: String,
+    input: ZetaPromptInput,
+    requested_output: String,
+    settled_editable_region: String,
+    requested_format: ZetaFormat,
+    zed_version: Option<String>,
+) -> Example {
+    // Use the format-specific range when the input carries `excerpt_ranges`;
+    // otherwise fall back to `editable_range_in_excerpt`.
+    let requested_editable_range = input
+        .excerpt_ranges
+        .as_ref()
+        .map(|ranges| excerpt_range_for_format(requested_format, ranges).0)
+        .unwrap_or_else(|| input.editable_range_in_excerpt.clone());
+
+    // Copied out before `input` is handed to `build_example_from_snowflake` below.
+    let base_cursor_excerpt = input.cursor_excerpt.to_string();
+
+    // `len()` is a byte length, so this bounds check is in bytes.
+    // NOTE(review): it does not verify that the bounds fall on UTF-8 character
+    // boundaries — confirm `build_output_patch` tolerates that.
+    let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
+        && requested_editable_range.end <= base_cursor_excerpt.len();
+    let mut example = build_example_from_snowflake(
+        request_id.clone(),
+        device_id,
+        time,
+        input,
+        vec!["settled".to_string()],
+        None,
+        zed_version,
+    );
+
+    if !requested_range_is_valid {
+        log::warn!(
+            "skipping malformed requested range for request {}: requested={:?} (base_len={})",
+            request_id,
+            requested_editable_range,
+            base_cursor_excerpt.len(),
+        );
+        // The example is still returned, just without expected/rejected patches.
+        return example;
+    }
+
+    let settled_replacement = settled_editable_region.as_str();
+    // Both patches use the same excerpt and range; only the replacement text differs.
+    let rejected_patch = build_output_patch(
+        &example.spec.cursor_path,
+        &base_cursor_excerpt,
+        &requested_editable_range,
+        &requested_output,
+    );
+    let expected_patch = build_output_patch(
+        &example.spec.cursor_path,
+        &base_cursor_excerpt,
+        &requested_editable_range,
+        settled_replacement,
+    );
+
+    example.spec.expected_patches = vec![expected_patch];
+    example.spec.rejected_patch = Some(rejected_patch);
+    example
+}
+
 fn rejected_examples_from_response<'a>(
     response: &'a SnowflakeStatementResponse,
     column_indices: &'a std::collections::HashMap<String, usize>,