@@ -5,24 +5,25 @@ use http_client::{AsyncBody, HttpClient, Method, Request};
use indoc::indoc;
use serde::Deserialize;
use serde_json::{Value as JsonValue, json};
+use std::fmt::Write as _;
use std::io::Read;
use std::sync::Arc;
use std::time::Duration;
use telemetry_events::EditPredictionRating;
-use zeta_prompt::ZetaPromptInput;
+use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
use crate::example::Example;
use crate::progress::{InfoStyle, Progress, Step};
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
-use std::fmt::Write as _;
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
+/// Event type of the "settled" telemetry rows that are joined against prediction requests.
+const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
@@ -34,6 +35,7 @@ pub struct MinCaptureVersion {
}
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
+/// Value sent as the `timeout` field of the settled-examples statement request (twice the default).
+const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
@@ -52,6 +54,11 @@ pub fn parse_requested_after_input(input: &str) -> Option<&str> {
input.strip_prefix("requested-after:")
}
+/// Recognizes an input token of the form `settled-after:{timestamp}`.
+///
+/// Returns the text following the `settled-after:` prefix (which may be empty), or `None`
+/// when `input` does not start with that prefix.
+pub fn parse_settled_after_input(input: &str) -> Option<&str> {
+    const PREFIX: &str = "settled-after:";
+    input.strip_prefix(PREFIX)
+}
+
/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
/// or `rated-negative-after:{timestamp}`.
/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
@@ -596,6 +603,163 @@ pub async fn fetch_requested_examples_after(
Ok(all_examples)
}
+/// Pulls "settled" edit-prediction examples from Snowflake.
+///
+/// For each entry in `after_timestamps`, one SQL statement is run that joins
+/// `Predictive Edit Requested` events (version `V3`, `can_collect_data = true`)
+/// newer than that timestamp with the `Edit Prediction Settled` event sharing the
+/// same `request_id`. The rows of the response, plus the rows of any additional
+/// result partitions, are passed through `settled_examples_from_response` and the
+/// resulting examples are appended to the returned vector.
+///
+/// `max_rows_per_timestamp` and `offset` are bound to the statement's `LIMIT` and
+/// `OFFSET`, so both apply to each timestamp's query separately.
+/// `min_capture_version` is accepted but currently not used by the query.
+///
+/// Returns an empty vector, without reading any environment variable, when
+/// `after_timestamps` is empty.
+///
+/// # Errors
+///
+/// Fails if `EP_SNOWFLAKE_API_KEY` or `EP_SNOWFLAKE_BASE_URL` cannot be read, if
+/// running the statement or fetching a partition fails, if
+/// `settled_examples_from_response` rejects a response, or if a multi-partition
+/// response carries no statement handle.
+pub async fn fetch_settled_examples_after(
+ http_client: Arc<dyn HttpClient>,
+ after_timestamps: &[String],
+ max_rows_per_timestamp: usize,
+ offset: usize,
+ background_executor: BackgroundExecutor,
+ min_capture_version: Option<MinCaptureVersion>,
+) -> Result<Vec<Example>> {
+ if after_timestamps.is_empty() {
+ return Ok(Vec::new());
+ }
+
+ let progress = Progress::global();
+
+ let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+ .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+ let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+ "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+ )?;
+ // The role is optional: an unreadable EP_SNOWFLAKE_ROLE simply becomes `None`.
+ let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+ let mut all_examples = Vec::new();
+
+ for after_date in after_timestamps.iter() {
+ let step_progress_name = format!("settled>{after_date}");
+ let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+ step_progress.set_substatus("querying");
+
+ // The `?` placeholders are, in order: requested event type, lower time bound,
+ // settled event type, LIMIT, OFFSET (matching bindings "1".."5" below).
+ let statement = indoc! {r#"
+ WITH requested AS (
+ SELECT
+ req.event_properties:request_id::string AS request_id,
+ req.device_id::string AS device_id,
+ req.time AS req_time,
+ req.time::string AS time,
+ req.event_properties:input AS input,
+ req.event_properties:format::string AS requested_format,
+ req.event_properties:output::string AS requested_output,
+ req.event_properties:zed_version::string AS zed_version
+ FROM events req
+ WHERE req.event_type = ?
+ AND req.event_properties:version = 'V3'
+ AND req.event_properties:input:can_collect_data = true
+ AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+ )
+ SELECT
+ req.request_id AS request_id,
+ req.device_id AS device_id,
+ req.time AS time,
+ req.input AS input,
+ req.requested_output AS requested_output,
+ settled.event_properties:settled_editable_region::string AS settled_editable_region,
+ req.requested_format AS requested_format,
+ req.zed_version AS zed_version
+ FROM requested req
+ INNER JOIN events settled
+ ON req.request_id = settled.event_properties:request_id::string
+ WHERE settled.event_type = ?
+ ORDER BY req.req_time ASC
+ LIMIT ?
+ OFFSET ?
+ "#};
+
+ // NOTE(review): no version filter is applied to this query; the parameter is
+ // explicitly discarded here so that it does not trigger an unused warning.
+ let _ = min_capture_version;
+ let request = json!({
+ "statement": statement,
+ "timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS,
+ "database": "EVENTS",
+ "schema": "PUBLIC",
+ "warehouse": "DBT",
+ "role": role,
+ "bindings": {
+ "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+ "2": { "type": "TEXT", "value": after_date },
+ "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
+ "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
+ "5": { "type": "FIXED", "value": offset.to_string() }
+ }
+ });
+
+ let response = run_sql_with_polling(
+ http_client.clone(),
+ &base_url,
+ &token,
+ &request,
+ &step_progress,
+ background_executor.clone(),
+ )
+ .await?;
+
+ // Prefer the row count reported in the metadata; otherwise count the rows
+ // present in this response. Only used for the progress display below.
+ let total_rows = response
+ .result_set_meta_data
+ .as_ref()
+ .and_then(|m| m.num_rows)
+ .unwrap_or(response.data.len() as i64);
+
+ // Absent metadata or an empty partition list is treated as a single partition.
+ let num_partitions = response
+ .result_set_meta_data
+ .as_ref()
+ .map(|m| m.partition_info.len())
+ .unwrap_or(1)
+ .max(1);
+
+ step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+ step_progress.set_substatus("parsing");
+
+ let column_indices = get_column_indices(
+ &response.result_set_meta_data,
+ &[
+ "request_id",
+ "device_id",
+ "time",
+ "input",
+ "requested_output",
+ "settled_editable_region",
+ "requested_format",
+ "zed_version",
+ ],
+ );
+
+ // Rows of the initial response are consumed here; only partitions 1.. are
+ // fetched separately below.
+ all_examples.extend(settled_examples_from_response(&response, &column_indices)?);
+
+ if num_partitions > 1 {
+ let statement_handle = response
+ .statement_handle
+ .as_ref()
+ .context("response has multiple partitions but no statementHandle")?;
+
+ for partition in 1..num_partitions {
+ // `partition` is zero-based; show it one-based in the progress text.
+ step_progress.set_substatus(format!(
+ "fetching partition {}/{}",
+ partition + 1,
+ num_partitions
+ ));
+
+ let partition_response = fetch_partition(
+ http_client.clone(),
+ &base_url,
+ &token,
+ statement_handle,
+ partition,
+ )
+ .await?;
+
+ // The column layout computed from the first response is reused for every partition.
+ all_examples.extend(settled_examples_from_response(
+ &partition_response,
+ &column_indices,
+ )?);
+ }
+ }
+
+ step_progress.set_substatus("done");
+ }
+
+ Ok(all_examples)
+}
+
pub async fn fetch_rated_examples_after(
http_client: Arc<dyn HttpClient>,
inputs: &[(String, Option<EditPredictionRating>)],
@@ -989,6 +1153,186 @@ fn requested_examples_from_response<'a>(
Ok(iter)
}
+/// Converts the rows of one Snowflake response (or result partition) into settled [`Example`]s.
+///
+/// `column_indices` maps column names to positions within each data row. A row yields an
+/// example only when `request_id`, `device_id`, `time`, `input`, `requested_output`,
+/// `settled_editable_region` and `requested_format` are all present and parse successfully;
+/// `zed_version` is optional. Any other row is logged, together with the names of the
+/// offending fields, and skipped.
+///
+/// # Errors
+///
+/// Returns an error when the response carries a code other than `SNOWFLAKE_SUCCESS_CODE`.
+fn settled_examples_from_response<'a>(
+    response: &'a SnowflakeStatementResponse,
+    column_indices: &'a std::collections::HashMap<String, usize>,
+) -> Result<impl Iterator<Item = Example> + 'a> {
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake sql api returned error code={code} message={}",
+                response.message.as_deref().unwrap_or("<no message>")
+            );
+        }
+    }
+
+    let iter = response
+        .data
+        .iter()
+        .enumerate()
+        .filter_map(move |(row_index, data_row)| {
+            // Cells are borrowed rather than cloned: the `input` cell is a whole JSON
+            // document, and copying it for every row is wasted work. A JSON null is
+            // treated the same as a missing column.
+            let get_value = |name: &str| -> Option<&'a JsonValue> {
+                let index = column_indices.get(name).copied()?;
+                data_row.get(index).filter(|value| !value.is_null())
+            };
+
+            // String cells are taken verbatim; any other JSON value is rendered as JSON text.
+            let get_string = |name: &str| -> Option<String> {
+                get_value(name).map(|value| match value {
+                    JsonValue::String(text) => text.clone(),
+                    other => other.to_string(),
+                })
+            };
+
+            let request_id = get_string("request_id");
+            let device_id = get_string("device_id");
+            let time = get_string("time");
+            // The `input` cell is accepted either as structured JSON or as a string holding
+            // JSON text; the latter is decoded first. Structured JSON is deserialized
+            // straight from the borrowed value, so no intermediate copy is made.
+            let input: Option<ZetaPromptInput> = match get_value("input") {
+                Some(JsonValue::String(encoded)) => serde_json::from_str::<JsonValue>(encoded)
+                    .ok()
+                    .and_then(|decoded| serde_json::from_value(decoded).ok()),
+                Some(structured) => ZetaPromptInput::deserialize(structured).ok(),
+                None => None,
+            };
+            let requested_output = get_string("requested_output");
+            let settled_editable_region = get_string("settled_editable_region");
+            let requested_format =
+                get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
+            let zed_version = get_string("zed_version");
+
+            // Move the fields into the match so that neither arm needs a clone: the first
+            // arm consumes them, the second only inspects which ones are absent.
+            match (
+                request_id,
+                device_id,
+                time,
+                input,
+                requested_output,
+                settled_editable_region,
+                requested_format,
+            ) {
+                (
+                    Some(request_id),
+                    Some(device_id),
+                    Some(time),
+                    Some(input),
+                    Some(requested_output),
+                    Some(settled_editable_region),
+                    Some(requested_format),
+                ) => Some(build_settled_example(
+                    request_id,
+                    device_id,
+                    time,
+                    input,
+                    requested_output,
+                    settled_editable_region,
+                    requested_format,
+                    zed_version,
+                )),
+                (
+                    request_id,
+                    device_id,
+                    time,
+                    input,
+                    requested_output,
+                    settled_editable_region,
+                    requested_format,
+                ) => {
+                    let missing_fields: Vec<&str> = [
+                        ("request_id", request_id.is_none()),
+                        ("device_id", device_id.is_none()),
+                        ("time", time.is_none()),
+                        ("input", input.is_none()),
+                        ("requested_output", requested_output.is_none()),
+                        ("settled_editable_region", settled_editable_region.is_none()),
+                        ("requested_format", requested_format.is_none()),
+                    ]
+                    .into_iter()
+                    .filter_map(|(name, is_missing)| is_missing.then_some(name))
+                    .collect();
+
+                    log::warn!(
+                        "skipping settled row {row_index}: [{}]",
+                        missing_fields.join(", "),
+                    );
+                    None
+                }
+            }
+        });
+
+    Ok(iter)
+}
+
+/// Builds an [`Example`] tagged `"settled"` from one joined request/settled row.
+///
+/// The editable range is taken from `input.excerpt_ranges` for `requested_format` when those
+/// ranges are present, and from `input.editable_range_in_excerpt` otherwise. When the range is
+/// usable, two patches are produced with `build_output_patch` over the cursor excerpt and that
+/// range: one from `requested_output`, stored as `rejected_patch`, and one from
+/// `settled_editable_region`, stored as the single entry of `expected_patches`.
+///
+/// When the range is malformed (reversed, out of bounds, or not on UTF-8 character
+/// boundaries of the excerpt), a warning is logged and the example is returned exactly as
+/// `build_example_from_snowflake` produced it.
+fn build_settled_example(
+    request_id: String,
+    device_id: String,
+    time: String,
+    input: ZetaPromptInput,
+    requested_output: String,
+    settled_editable_region: String,
+    requested_format: ZetaFormat,
+    zed_version: Option<String>,
+) -> Example {
+    let requested_editable_range = input
+        .excerpt_ranges
+        .as_ref()
+        .map(|ranges| excerpt_range_for_format(requested_format, ranges).0)
+        .unwrap_or_else(|| input.editable_range_in_excerpt.clone());
+
+    // `input` is moved into the example below, so keep an owned copy of the excerpt text.
+    let base_cursor_excerpt = input.cursor_excerpt.to_string();
+
+    // The offsets come from telemetry, so validate them before use. They must be ordered and
+    // must fall on UTF-8 character boundaries of the excerpt. `is_char_boundary` returns
+    // false for any index past the end of the string, so it enforces the bounds as well.
+    // NOTE(review): assumes `build_output_patch` slices the excerpt by these byte offsets,
+    // which would panic on a mid-character offset - confirm against its implementation.
+    let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
+        && base_cursor_excerpt.is_char_boundary(requested_editable_range.start)
+        && base_cursor_excerpt.is_char_boundary(requested_editable_range.end);
+
+    let mut example = build_example_from_snowflake(
+        request_id.clone(),
+        device_id,
+        time,
+        input,
+        vec!["settled".to_string()],
+        None,
+        zed_version,
+    );
+
+    if !requested_range_is_valid {
+        log::warn!(
+            "skipping malformed requested range for request {}: requested={:?} (base_len={})",
+            request_id,
+            requested_editable_range,
+            base_cursor_excerpt.len(),
+        );
+        return example;
+    }
+
+    // Both patches cover the same range of the same excerpt and differ only in the
+    // replacement text.
+    let rejected_patch = build_output_patch(
+        &example.spec.cursor_path,
+        &base_cursor_excerpt,
+        &requested_editable_range,
+        &requested_output,
+    );
+    let expected_patch = build_output_patch(
+        &example.spec.cursor_path,
+        &base_cursor_excerpt,
+        &requested_editable_range,
+        &settled_editable_region,
+    );
+
+    example.spec.expected_patches = vec![expected_patch];
+    example.spec.rejected_patch = Some(rejected_patch);
+    example
+}
+
fn rejected_examples_from_response<'a>(
response: &'a SnowflakeStatementResponse,
column_indices: &'a std::collections::HashMap<String, usize>,