Add support for pulling all requested prediction inputs

Max Brunsfeld and Ben Kunkle created

Co-authored-by: Ben Kunkle <ben@zed.dev>

Change summary

crates/edit_prediction_cli/src/main.rs          |  22 +
crates/edit_prediction_cli/src/pull_examples.rs | 252 ++++++++++++++++++
2 files changed, 262 insertions(+), 12 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/main.rs 🔗

@@ -435,6 +435,7 @@ async fn load_examples(
 ) -> anyhow::Result<Vec<Example>> {
     let mut captured_after_timestamps = Vec::new();
     let mut rejected_after_timestamps = Vec::new();
+    let mut requested_after_timestamps = Vec::new();
     let mut file_inputs = Vec::new();
 
     for input in &args.inputs {
@@ -445,6 +446,10 @@ async fn load_examples(
             pull_examples::parse_rejected_after_input(input_string.as_ref())
         {
             rejected_after_timestamps.push(timestamp.to_string());
+        } else if let Some(timestamp) =
+            pull_examples::parse_requested_after_input(input_string.as_ref())
+        {
+            requested_after_timestamps.push(timestamp.to_string());
         } else {
             file_inputs.push(input.clone());
         }
@@ -481,14 +486,27 @@ async fn load_examples(
             rejected_after_timestamps.sort();
 
             let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
-                http_client,
+                http_client.clone(),
                 &rejected_after_timestamps,
                 max_rows_per_timestamp,
-                background_executor,
+                background_executor.clone(),
             )
             .await?;
             examples.append(&mut rejected_examples);
         }
+
+        if !requested_after_timestamps.is_empty() {
+            requested_after_timestamps.sort();
+
+            let mut requested_examples = pull_examples::fetch_requested_examples_after(
+                http_client,
+                &requested_after_timestamps,
+                max_rows_per_timestamp,
+                background_executor,
+            )
+            .await?;
+            examples.append(&mut requested_examples);
+        }
     }
 
     crate::example::sort_examples_by_repo_and_rev(&mut examples);

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -39,6 +39,11 @@ pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
     input.strip_prefix("rejected-after:")
 }
 
+/// Parse an input token of the form `requested-after:{timestamp}`.
+pub fn parse_requested_after_input(input: &str) -> Option<&str> {
+    input.strip_prefix("requested-after:")
+}
+
 pub async fn fetch_captured_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
@@ -556,6 +561,204 @@ pub async fn fetch_rejected_examples_after(
     Ok(all_examples)
 }
 
+pub async fn fetch_requested_examples_after(
+    http_client: Arc<dyn HttpClient>,
+    after_timestamps: &[String],
+    max_rows_per_timestamp: usize,
+    background_executor: BackgroundExecutor,
+) -> Result<Vec<Example>> {
+    if after_timestamps.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let progress = Progress::global();
+
+    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+    )?;
+    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+    let mut all_examples = Vec::new();
+
+    for after_date in after_timestamps.iter() {
+        let step_progress_name = format!("requested>{after_date}");
+        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+        step_progress.set_substatus("querying");
+
+        let statement = indoc! {r#"
+            SELECT
+                req.event_properties:request_id::string AS request_id,
+                req.device_id::string AS device_id,
+                req.time::string AS time,
+                req.event_properties:input AS input
+            FROM events req
+            WHERE req.event_type = ?
+                AND req.event_properties:version = 'V3'
+                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+            ORDER BY req.time ASC
+            LIMIT ?
+        "#};
+
+        let request = json!({
+            "statement": statement,
+            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            "database": "EVENTS",
+            "schema": "PUBLIC",
+            "warehouse": "DBT",
+            "role": role,
+            "bindings": {
+                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+                "2": { "type": "TEXT", "value": after_date },
+                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+            }
+        });
+
+        let response = run_sql_with_polling(
+            http_client.clone(),
+            &base_url,
+            &token,
+            &request,
+            &step_progress,
+            background_executor.clone(),
+        )
+        .await?;
+
+        let total_rows = response
+            .result_set_meta_data
+            .as_ref()
+            .and_then(|m| m.num_rows)
+            .unwrap_or(response.data.len() as i64);
+
+        let num_partitions = response
+            .result_set_meta_data
+            .as_ref()
+            .map(|m| m.partition_info.len())
+            .unwrap_or(1)
+            .max(1);
+
+        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+        step_progress.set_substatus("parsing");
+
+        let column_indices = get_column_indices(
+            &response.result_set_meta_data,
+            &["request_id", "device_id", "time", "input"],
+        );
+
+        all_examples.extend(requested_examples_from_response(
+            &response,
+            &column_indices,
+        )?);
+
+        if num_partitions > 1 {
+            let statement_handle = response
+                .statement_handle
+                .as_ref()
+                .context("response has multiple partitions but no statementHandle")?;
+
+            for partition in 1..num_partitions {
+                step_progress.set_substatus(format!(
+                    "fetching partition {}/{}",
+                    partition + 1,
+                    num_partitions
+                ));
+
+                let partition_response = fetch_partition(
+                    http_client.clone(),
+                    &base_url,
+                    &token,
+                    statement_handle,
+                    partition,
+                )
+                .await?;
+
+                all_examples.extend(requested_examples_from_response(
+                    &partition_response,
+                    &column_indices,
+                )?);
+            }
+        }
+
+        step_progress.set_substatus("done");
+    }
+
+    Ok(all_examples)
+}
+
+fn requested_examples_from_response<'a>(
+    response: &'a SnowflakeStatementResponse,
+    column_indices: &'a std::collections::HashMap<String, usize>,
+) -> Result<impl Iterator<Item = Example> + 'a> {
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake sql api returned error code={code} message={}",
+                response.message.as_deref().unwrap_or("<no message>")
+            );
+        }
+    }
+
+    let iter = response
+        .data
+        .iter()
+        .enumerate()
+        .filter_map(move |(row_index, data_row)| {
+            let get_string = |name: &str| -> Option<String> {
+                let index = column_indices.get(name).copied()?;
+                match data_row.get(index)? {
+                    JsonValue::String(s) => Some(s.clone()),
+                    JsonValue::Null => None,
+                    other => Some(other.to_string()),
+                }
+            };
+
+            let get_json = |name: &str| -> Option<JsonValue> {
+                let index = column_indices.get(name).copied()?;
+                let value = data_row.get(index)?;
+                if value.is_null() {
+                    return None;
+                }
+                match value {
+                    JsonValue::String(s) => serde_json::from_str(s).ok(),
+                    other => Some(other.clone()),
+                }
+            };
+
+            let request_id_str = get_string("request_id");
+            let device_id = get_string("device_id");
+            let time = get_string("time");
+            let input_json = get_json("input");
+            let input: Option<ZetaPromptInput> =
+                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
+
+            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
+                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
+                    Some(build_example_from_snowflake(
+                        request_id,
+                        device_id,
+                        time,
+                        input,
+                        vec!["requested".to_string()],
+                        None,
+                    ))
+                }
+                _ => {
+                    log::warn!(
+                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
+                        request_id_str.is_some(),
+                        device_id.is_some(),
+                        time.is_some(),
+                        input_json.is_some(),
+                    );
+                    None
+                }
+            }
+        });
+
+    Ok(iter)
+}
+
 fn rejected_examples_from_response(
     response: &SnowflakeStatementResponse,
 ) -> Result<impl Iterator<Item = Example> + '_> {
@@ -666,6 +869,37 @@ fn build_rejected_example(
     output: String,
     was_shown: bool,
     reason: String,
+) -> Example {
+    let rejected_patch = build_output_patch(
+        &input.cursor_path,
+        input.cursor_excerpt.as_ref(),
+        &input.editable_range_in_excerpt,
+        &output,
+    );
+    let mut example = build_example_from_snowflake(
+        request_id,
+        device_id,
+        time,
+        input,
+        vec![format!("rejection:{}", reason.to_lowercase())],
+        Some(RejectionInfo { reason, was_shown }),
+    );
+    example.spec.rejected_patch = Some(rejected_patch);
+    example
+}
+
+struct RejectionInfo {
+    reason: String,
+    was_shown: bool,
+}
+
+fn build_example_from_snowflake(
+    request_id: String,
+    device_id: String,
+    time: String,
+    input: ZetaPromptInput,
+    tags: Vec<String>,
+    rejection: Option<RejectionInfo>,
 ) -> Example {
     let events: Vec<CapturedEvent> = input
         .events
@@ -715,25 +949,23 @@ fn build_rejected_example(
         edit_history.push('\n');
     }
 
-    let rejected_patch = build_rejected_patch(
-        &input.cursor_path,
-        cursor_excerpt,
-        &input.editable_range_in_excerpt,
-        &output,
-    );
+    let (rejection_reason, was_shown) = match &rejection {
+        Some(r) => (r.reason.clone(), r.was_shown),
+        None => (String::new(), false),
+    };
 
     let spec = ExampleSpec {
         name: request_id.clone(),
         repository_url: String::new(),
         revision: String::new(),
-        tags: vec![format!("rejection:{}", reason.to_lowercase())],
+        tags,
         reasoning: None,
         uncommitted_diff: String::new(),
         cursor_path: input.cursor_path.clone(),
         cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
         edit_history,
         expected_patches: Vec::new(),
-        rejected_patch: Some(rejected_patch),
+        rejected_patch: None,
         captured_prompt_input: Some(CapturedPromptInput {
             cursor_file_content: cursor_excerpt.to_string(),
             cursor_offset,
@@ -746,7 +978,7 @@ fn build_rejected_example(
             request_id,
             device_id,
             time,
-            rejection_reason: reason,
+            rejection_reason,
             was_shown,
         }),
     };
@@ -784,7 +1016,7 @@ fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
     format!("{}[CURSOR_POSITION]{}", before, after)
 }
 
-fn build_rejected_patch(
+fn build_output_patch(
     cursor_path: &std::path::Path,
     cursor_excerpt: &str,
     editable_range: &std::ops::Range<usize>,