@@ -39,6 +39,11 @@ pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
input.strip_prefix("rejected-after:")
}
+/// Parse an input token of the form `requested-after:{timestamp}`.
+pub fn parse_requested_after_input(input: &str) -> Option<&str> {
+ input.strip_prefix("requested-after:")
+}
+
pub async fn fetch_captured_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
@@ -556,6 +561,204 @@ pub async fn fetch_rejected_examples_after(
Ok(all_examples)
}
+pub async fn fetch_requested_examples_after(
+ http_client: Arc<dyn HttpClient>,
+ after_timestamps: &[String],
+ max_rows_per_timestamp: usize,
+ background_executor: BackgroundExecutor,
+) -> Result<Vec<Example>> {
+ if after_timestamps.is_empty() {
+ return Ok(Vec::new());
+ }
+
+ let progress = Progress::global();
+
+ let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+ .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+ let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+ "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+ )?;
+ let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+ let mut all_examples = Vec::new();
+
+ for after_date in after_timestamps.iter() {
+ let step_progress_name = format!("requested>{after_date}");
+ let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+ step_progress.set_substatus("querying");
+
+ let statement = indoc! {r#"
+ SELECT
+ req.event_properties:request_id::string AS request_id,
+ req.device_id::string AS device_id,
+ req.time::string AS time,
+ req.event_properties:input AS input
+ FROM events req
+ WHERE req.event_type = ?
+ AND req.event_properties:version = 'V3'
+ AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+ ORDER BY req.time ASC
+ LIMIT ?
+ "#};
+
+ let request = json!({
+ "statement": statement,
+ "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ "database": "EVENTS",
+ "schema": "PUBLIC",
+ "warehouse": "DBT",
+ "role": role,
+ "bindings": {
+ "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+ "2": { "type": "TEXT", "value": after_date },
+ "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+ }
+ });
+
+ let response = run_sql_with_polling(
+ http_client.clone(),
+ &base_url,
+ &token,
+ &request,
+ &step_progress,
+ background_executor.clone(),
+ )
+ .await?;
+
+ let total_rows = response
+ .result_set_meta_data
+ .as_ref()
+ .and_then(|m| m.num_rows)
+ .unwrap_or(response.data.len() as i64);
+
+ let num_partitions = response
+ .result_set_meta_data
+ .as_ref()
+ .map(|m| m.partition_info.len())
+ .unwrap_or(1)
+ .max(1);
+
+ step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+ step_progress.set_substatus("parsing");
+
+ let column_indices = get_column_indices(
+ &response.result_set_meta_data,
+ &["request_id", "device_id", "time", "input"],
+ );
+
+ all_examples.extend(requested_examples_from_response(
+ &response,
+ &column_indices,
+ )?);
+
+ if num_partitions > 1 {
+ let statement_handle = response
+ .statement_handle
+ .as_ref()
+ .context("response has multiple partitions but no statementHandle")?;
+
+ for partition in 1..num_partitions {
+ step_progress.set_substatus(format!(
+ "fetching partition {}/{}",
+ partition + 1,
+ num_partitions
+ ));
+
+ let partition_response = fetch_partition(
+ http_client.clone(),
+ &base_url,
+ &token,
+ statement_handle,
+ partition,
+ )
+ .await?;
+
+ all_examples.extend(requested_examples_from_response(
+ &partition_response,
+ &column_indices,
+ )?);
+ }
+ }
+
+ step_progress.set_substatus("done");
+ }
+
+ Ok(all_examples)
+}
+
+fn requested_examples_from_response<'a>(
+ response: &'a SnowflakeStatementResponse,
+ column_indices: &'a std::collections::HashMap<String, usize>,
+) -> Result<impl Iterator<Item = Example> + 'a> {
+ if let Some(code) = &response.code {
+ if code != SNOWFLAKE_SUCCESS_CODE {
+ anyhow::bail!(
+ "snowflake sql api returned error code={code} message={}",
+ response.message.as_deref().unwrap_or("<no message>")
+ );
+ }
+ }
+
+ let iter = response
+ .data
+ .iter()
+ .enumerate()
+ .filter_map(move |(row_index, data_row)| {
+ let get_string = |name: &str| -> Option<String> {
+ let index = column_indices.get(name).copied()?;
+ match data_row.get(index)? {
+ JsonValue::String(s) => Some(s.clone()),
+ JsonValue::Null => None,
+ other => Some(other.to_string()),
+ }
+ };
+
+ let get_json = |name: &str| -> Option<JsonValue> {
+ let index = column_indices.get(name).copied()?;
+ let value = data_row.get(index)?;
+ if value.is_null() {
+ return None;
+ }
+ match value {
+ JsonValue::String(s) => serde_json::from_str(s).ok(),
+ other => Some(other.clone()),
+ }
+ };
+
+ let request_id_str = get_string("request_id");
+ let device_id = get_string("device_id");
+ let time = get_string("time");
+ let input_json = get_json("input");
+ let input: Option<ZetaPromptInput> =
+ input_json.clone().and_then(|v| serde_json::from_value(v).ok());
+
+ match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
+ (Some(request_id), Some(device_id), Some(time), Some(input)) => {
+ Some(build_example_from_snowflake(
+ request_id,
+ device_id,
+ time,
+ input,
+ vec!["requested".to_string()],
+ None,
+ ))
+ }
+ _ => {
+ log::warn!(
+ "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
+ request_id_str.is_some(),
+ device_id.is_some(),
+ time.is_some(),
+ input_json.is_some(),
+ );
+ None
+ }
+ }
+ });
+
+ Ok(iter)
+}
+
fn rejected_examples_from_response(
response: &SnowflakeStatementResponse,
) -> Result<impl Iterator<Item = Example> + '_> {
@@ -666,6 +869,37 @@ fn build_rejected_example(
output: String,
was_shown: bool,
reason: String,
+) -> Example {
+ let rejected_patch = build_output_patch(
+ &input.cursor_path,
+ input.cursor_excerpt.as_ref(),
+ &input.editable_range_in_excerpt,
+ &output,
+ );
+ let mut example = build_example_from_snowflake(
+ request_id,
+ device_id,
+ time,
+ input,
+ vec![format!("rejection:{}", reason.to_lowercase())],
+ Some(RejectionInfo { reason, was_shown }),
+ );
+ example.spec.rejected_patch = Some(rejected_patch);
+ example
+}
+
+struct RejectionInfo {
+ reason: String,
+ was_shown: bool,
+}
+
+fn build_example_from_snowflake(
+ request_id: String,
+ device_id: String,
+ time: String,
+ input: ZetaPromptInput,
+ tags: Vec<String>,
+ rejection: Option<RejectionInfo>,
) -> Example {
let events: Vec<CapturedEvent> = input
.events
@@ -715,25 +949,23 @@ fn build_rejected_example(
edit_history.push('\n');
}
- let rejected_patch = build_rejected_patch(
- &input.cursor_path,
- cursor_excerpt,
- &input.editable_range_in_excerpt,
- &output,
- );
+ let (rejection_reason, was_shown) = match &rejection {
+ Some(r) => (r.reason.clone(), r.was_shown),
+ None => (String::new(), false),
+ };
let spec = ExampleSpec {
name: request_id.clone(),
repository_url: String::new(),
revision: String::new(),
- tags: vec![format!("rejection:{}", reason.to_lowercase())],
+ tags,
reasoning: None,
uncommitted_diff: String::new(),
cursor_path: input.cursor_path.clone(),
cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
edit_history,
expected_patches: Vec::new(),
- rejected_patch: Some(rejected_patch),
+ rejected_patch: None,
captured_prompt_input: Some(CapturedPromptInput {
cursor_file_content: cursor_excerpt.to_string(),
cursor_offset,
@@ -746,7 +978,7 @@ fn build_rejected_example(
request_id,
device_id,
time,
- rejection_reason: reason,
+ rejection_reason,
was_shown,
}),
};
@@ -784,7 +1016,7 @@ fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
format!("{}[CURSOR_POSITION]{}", before, after)
}
-fn build_rejected_patch(
+fn build_output_patch(
cursor_path: &std::path::Path,
cursor_excerpt: &str,
editable_range: &std::ops::Range<usize>,