diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 4d506c334c600e4946f4a5711677b79c3e6281a6..6790491c69ae2888ca78bbd05db76ccacb92f974 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -435,6 +435,7 @@ async fn load_examples( ) -> anyhow::Result> { let mut captured_after_timestamps = Vec::new(); let mut rejected_after_timestamps = Vec::new(); + let mut requested_after_timestamps = Vec::new(); let mut file_inputs = Vec::new(); for input in &args.inputs { @@ -445,6 +446,10 @@ async fn load_examples( pull_examples::parse_rejected_after_input(input_string.as_ref()) { rejected_after_timestamps.push(timestamp.to_string()); + } else if let Some(timestamp) = + pull_examples::parse_requested_after_input(input_string.as_ref()) + { + requested_after_timestamps.push(timestamp.to_string()); } else { file_inputs.push(input.clone()); } @@ -481,14 +486,27 @@ async fn load_examples( rejected_after_timestamps.sort(); let mut rejected_examples = pull_examples::fetch_rejected_examples_after( - http_client, + http_client.clone(), &rejected_after_timestamps, max_rows_per_timestamp, - background_executor, + background_executor.clone(), ) .await?; examples.append(&mut rejected_examples); } + + if !requested_after_timestamps.is_empty() { + requested_after_timestamps.sort(); + + let mut requested_examples = pull_examples::fetch_requested_examples_after( + http_client, + &requested_after_timestamps, + max_rows_per_timestamp, + background_executor, + ) + .await?; + examples.append(&mut requested_examples); + } } crate::example::sort_examples_by_repo_and_rev(&mut examples); diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index fadc724f067d5f1cc907202894cf798f6d78bab3..ce6886fd87fd2464940076d475ce3bc6f0061d9b 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -39,6 +39,11 @@ pub fn parse_rejected_after_input(input: &str) -> Option<&str> { input.strip_prefix("rejected-after:") } +/// Parse an input token of the form `requested-after:{timestamp}`. +pub fn parse_requested_after_input(input: &str) -> Option<&str> { + input.strip_prefix("requested-after:") +} + pub async fn fetch_captured_examples_after( http_client: Arc, after_timestamps: &[String], @@ -556,6 +561,204 @@ pub async fn fetch_rejected_examples_after( Ok(all_examples) } +pub async fn fetch_requested_examples_after( + http_client: Arc, + after_timestamps: &[String], + max_rows_per_timestamp: usize, + background_executor: BackgroundExecutor, +) -> Result> { + if after_timestamps.is_empty() { + return Ok(Vec::new()); + } + + let progress = Progress::global(); + + let token = std::env::var("EP_SNOWFLAKE_API_KEY") + .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; + let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( + "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", + )?; + let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); + + let mut all_examples = Vec::new(); + + for after_date in after_timestamps.iter() { + let step_progress_name = format!("requested>{after_date}"); + let step_progress = progress.start(Step::PullExamples, &step_progress_name); + step_progress.set_substatus("querying"); + + let statement = indoc! {r#" + SELECT + req.event_properties:request_id::string AS request_id, + req.device_id::string AS device_id, + req.time::string AS time, + req.event_properties:input AS input + FROM events req + WHERE req.event_type = ? + AND req.event_properties:version = 'V3' + AND req.time > TRY_TO_TIMESTAMP_NTZ(?) + ORDER BY req.time ASC + LIMIT ? + "#}; + + let request = json!({ + "statement": statement, + "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, + "database": "EVENTS", + "schema": "PUBLIC", + "warehouse": "DBT", + "role": role, + "bindings": { + "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, + "2": { "type": "TEXT", "value": after_date }, + "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() } + } + }); + + let response = run_sql_with_polling( + http_client.clone(), + &base_url, + &token, + &request, + &step_progress, + background_executor.clone(), + ) + .await?; + + let total_rows = response + .result_set_meta_data + .as_ref() + .and_then(|m| m.num_rows) + .unwrap_or(response.data.len() as i64); + + let num_partitions = response + .result_set_meta_data + .as_ref() + .map(|m| m.partition_info.len()) + .unwrap_or(1) + .max(1); + + step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); + step_progress.set_substatus("parsing"); + + let column_indices = get_column_indices( + &response.result_set_meta_data, + &["request_id", "device_id", "time", "input"], + ); + + all_examples.extend(requested_examples_from_response( + &response, + &column_indices, + )?); + + if num_partitions > 1 { + let statement_handle = response + .statement_handle + .as_ref() + .context("response has multiple partitions but no statementHandle")?; + + for partition in 1..num_partitions { + step_progress.set_substatus(format!( + "fetching partition {}/{}", + partition + 1, + num_partitions + )); + + let partition_response = fetch_partition( + http_client.clone(), + &base_url, + &token, + statement_handle, + partition, + ) + .await?; + + all_examples.extend(requested_examples_from_response( + &partition_response, + &column_indices, + )?); + } + } + + step_progress.set_substatus("done"); + } + + Ok(all_examples) +} + +fn requested_examples_from_response<'a>( + response: &'a SnowflakeStatementResponse, + column_indices: &'a std::collections::HashMap, +) -> Result + 'a> { + if let Some(code) = &response.code { + if code != SNOWFLAKE_SUCCESS_CODE { + anyhow::bail!( + "snowflake sql api returned error code={code} message={}", + response.message.as_deref().unwrap_or("") + ); + } + } + + let iter = response + .data + .iter() + .enumerate() + .filter_map(move |(row_index, data_row)| { + let get_string = |name: &str| -> Option { + let index = column_indices.get(name).copied()?; + match data_row.get(index)? { + JsonValue::String(s) => Some(s.clone()), + JsonValue::Null => None, + other => Some(other.to_string()), + } + }; + + let get_json = |name: &str| -> Option { + let index = column_indices.get(name).copied()?; + let value = data_row.get(index)?; + if value.is_null() { + return None; + } + match value { + JsonValue::String(s) => serde_json::from_str(s).ok(), + other => Some(other.clone()), + } + }; + + let request_id_str = get_string("request_id"); + let device_id = get_string("device_id"); + let time = get_string("time"); + let input_json = get_json("input"); + let input: Option = + input_json.clone().and_then(|v| serde_json::from_value(v).ok()); + + match (request_id_str.clone(), device_id.clone(), time.clone(), input) { + (Some(request_id), Some(device_id), Some(time), Some(input)) => { + Some(build_example_from_snowflake( + request_id, + device_id, + time, + input, + vec!["requested".to_string()], + None, + )) + } + _ => { + log::warn!( + "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}", + request_id_str.is_some(), + device_id.is_some(), + time.is_some(), + input_json.is_some(), + ); + None + } + } + }); + + Ok(iter) +} + fn rejected_examples_from_response( response: &SnowflakeStatementResponse, ) -> Result + '_> { @@ -666,6 +869,37 @@ fn build_rejected_example( output: String, was_shown: bool, reason: String, +) -> Example { + let rejected_patch = build_output_patch( + &input.cursor_path, + input.cursor_excerpt.as_ref(), + &input.editable_range_in_excerpt, + &output, + ); + let mut example = build_example_from_snowflake( + request_id, + device_id, + time, + input, + vec![format!("rejection:{}", reason.to_lowercase())], + Some(RejectionInfo { reason, was_shown }), + ); + example.spec.rejected_patch = Some(rejected_patch); + example +} + +struct RejectionInfo { + reason: String, + was_shown: bool, +} + +fn build_example_from_snowflake( + request_id: String, + device_id: String, + time: String, + input: ZetaPromptInput, + tags: Vec, + rejection: Option, ) -> Example { let events: Vec = input .events @@ -715,25 +949,23 @@ fn build_rejected_example( edit_history.push('\n'); } - let rejected_patch = build_rejected_patch( - &input.cursor_path, - cursor_excerpt, - &input.editable_range_in_excerpt, - &output, - ); + let (rejection_reason, was_shown) = match &rejection { + Some(r) => (r.reason.clone(), r.was_shown), + None => (String::new(), false), + }; let spec = ExampleSpec { name: request_id.clone(), repository_url: String::new(), revision: String::new(), - tags: vec![format!("rejection:{}", reason.to_lowercase())], + tags, reasoning: None, uncommitted_diff: String::new(), cursor_path: input.cursor_path.clone(), cursor_position: build_cursor_position(cursor_excerpt, cursor_offset), edit_history, expected_patches: Vec::new(), - rejected_patch: Some(rejected_patch), + rejected_patch: None, captured_prompt_input: Some(CapturedPromptInput { cursor_file_content: cursor_excerpt.to_string(), cursor_offset, @@ -746,7 +978,7 @@ fn build_rejected_example( request_id, device_id, time, - rejection_reason: reason, + rejection_reason, was_shown, }), }; @@ -784,7 +1016,7 @@ fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String { format!("{}[CURSOR_POSITION]{}", before, after) } -fn build_rejected_patch( +fn build_output_patch( cursor_path: &std::path::Path, cursor_excerpt: &str, editable_range: &std::ops::Range,