diff --git a/crates/edit_prediction_cli/src/pull_examples.rs b/crates/edit_prediction_cli/src/pull_examples.rs index e34fc62c031cbfa411d9d5a701a3e327d0be8166..b53a3d5546e1a5697550ed24715f049c36c98178 100644 --- a/crates/edit_prediction_cli/src/pull_examples.rs +++ b/crates/edit_prediction_cli/src/pull_examples.rs @@ -152,6 +152,103 @@ async fn run_sql_with_polling( Ok(response) } +struct SnowflakeConfig { + token: String, + base_url: String, + role: Option, +} + +async fn fetch_examples_with_query( + http_client: Arc, + step_progress: &crate::progress::StepProgress, + background_executor: BackgroundExecutor, + statement: &str, + bindings: JsonValue, + timeout_seconds: u64, + required_columns: &[&str], + parse_response: for<'a> fn( + &'a SnowflakeStatementResponse, + &'a std::collections::HashMap, + ) -> Result + 'a>>, +) -> Result> { + let snowflake = SnowflakeConfig { + token: std::env::var("EP_SNOWFLAKE_API_KEY") + .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?, + base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context( + "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", + )?, + role: std::env::var("EP_SNOWFLAKE_ROLE").ok(), + }; + let request = json!({ + "statement": statement, + "timeout": timeout_seconds, + "database": "EVENTS", + "schema": "PUBLIC", + "warehouse": "DBT", + "role": snowflake.role.as_deref(), + "bindings": bindings + }); + + let response = run_sql_with_polling( + http_client.clone(), + &snowflake.base_url, + &snowflake.token, + &request, + step_progress, + background_executor, + ) + .await?; + + let total_rows = response + .result_set_meta_data + .as_ref() + .and_then(|meta| meta.num_rows) + .unwrap_or(response.data.len() as i64); + let partition_count = response + .result_set_meta_data + .as_ref() + .map(|meta| meta.partition_info.len()) + .unwrap_or(1) + .max(1); + + step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); + step_progress.set_substatus("parsing"); + + let column_indices = get_column_indices(&response.result_set_meta_data, required_columns); + + let mut parsed_examples = Vec::with_capacity(total_rows as usize); + parsed_examples.extend(parse_response(&response, &column_indices)?); + + if partition_count > 1 { + let statement_handle = response + .statement_handle + .as_ref() + .context("response has multiple partitions but no statementHandle")?; + + for partition in 1..partition_count { + step_progress.set_substatus(format!( + "fetching partition {}/{}", + partition + 1, + partition_count + )); + + let partition_response = fetch_partition( + http_client.clone(), + &snowflake.base_url, + &snowflake.token, + statement_handle, + partition, + ) + .await?; + + parsed_examples.extend(parse_response(&partition_response, &column_indices)?); + } + } + + step_progress.set_substatus("done"); + Ok(parsed_examples) +} + pub(crate) async fn fetch_partition( http_client: Arc, base_url: &str, @@ -305,13 +402,6 @@ pub async fn fetch_rejected_examples_after( let progress = Progress::global(); - let token = std::env::var("EP_SNOWFLAKE_API_KEY") - .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; - let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( - "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", - )?; - let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); - let mut all_examples = Vec::new(); for after_date in after_timestamps.iter() { @@ -319,10 +409,11 @@ pub async fn fetch_rejected_examples_after( let step_progress = progress.start(Step::PullExamples, &step_progress_name); step_progress.set_substatus("querying"); - // Join rejected events with their corresponding request events to get the full context. - // We filter for V3 sampling data which contains the structured input we need. - // We also filter for predictions that were actually shown to the user (was_shown = true) - // to focus on explicit user rejections rather than implicit cancellations. + let min_minor_str = min_capture_version.map(|version| version.minor.to_string()); + let min_patch_str = min_capture_version.map(|version| version.patch.to_string()); + let min_minor_str_ref = min_minor_str.as_deref(); + let min_patch_str_ref = min_patch_str.as_deref(); + let statement = indoc! {r#" SELECT req.event_properties:request_id::string AS request_id, @@ -355,58 +446,25 @@ pub async fn fetch_rejected_examples_after( OFFSET ? "#}; - let min_minor_str = min_capture_version.map(|v| v.minor.to_string()); - let min_patch_str = min_capture_version.map(|v| v.patch.to_string()); - let min_minor_str_ref = min_minor_str.as_deref(); - let min_patch_str_ref = min_patch_str.as_deref(); - let request = json!({ - "statement": statement, - "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": role, - "bindings": { - "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT }, - "3": { "type": "TEXT", "value": after_date }, - "4": { "type": "FIXED", "value": min_minor_str_ref }, - "5": { "type": "FIXED", "value": min_minor_str_ref }, - "6": { "type": "FIXED", "value": min_minor_str_ref }, - "7": { "type": "FIXED", "value": min_patch_str_ref }, - "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "9": { "type": "FIXED", "value": offset.to_string() } - } + let bindings = json!({ + "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, + "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT }, + "3": { "type": "TEXT", "value": after_date }, + "4": { "type": "FIXED", "value": min_minor_str_ref }, + "5": { "type": "FIXED", "value": min_minor_str_ref }, + "6": { "type": "FIXED", "value": min_minor_str_ref }, + "7": { "type": "FIXED", "value": min_patch_str_ref }, + "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, + "9": { "type": "FIXED", "value": offset.to_string() } }); - let response = run_sql_with_polling( + let examples = fetch_examples_with_query( http_client.clone(), - &base_url, - &token, - &request, &step_progress, background_executor.clone(), - ) - .await?; - - let total_rows = response - .result_set_meta_data - .as_ref() - .and_then(|m| m.num_rows) - .unwrap_or(response.data.len() as i64); - - let num_partitions = response - .result_set_meta_data - .as_ref() - .map(|m| m.partition_info.len()) - .unwrap_or(1) - .max(1); - - step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); - step_progress.set_substatus("parsing"); - - let column_indices = get_column_indices( - &response.result_set_meta_data, + statement, + bindings, + DEFAULT_STATEMENT_TIMEOUT_SECONDS, &[ "request_id", "device_id", @@ -418,40 +476,11 @@ pub async fn fetch_rejected_examples_after( "reason", "zed_version", ], - ); - - all_examples.extend(rejected_examples_from_response(&response, &column_indices)?); - - if num_partitions > 1 { - let statement_handle = response - .statement_handle - .as_ref() - .context("response has multiple partitions but no statementHandle")?; - - for partition in 1..num_partitions { - step_progress.set_substatus(format!( - "fetching partition {}/{}", - partition + 1, - num_partitions - )); - - let partition_response = fetch_partition( - http_client.clone(), - &base_url, - &token, - statement_handle, - partition, - ) - .await?; - - all_examples.extend(rejected_examples_from_response( - &partition_response, - &column_indices, - )?); - } - } + rejected_examples_from_response, + ) + .await?; - step_progress.set_substatus("done"); + all_examples.extend(examples); } Ok(all_examples) @@ -471,13 +500,6 @@ pub async fn fetch_requested_examples_after( let progress = Progress::global(); - let token = std::env::var("EP_SNOWFLAKE_API_KEY") - .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; - let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( - "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", - )?; - let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); - let mut all_examples = Vec::new(); for after_date in after_timestamps.iter() { @@ -485,6 +507,11 @@ pub async fn fetch_requested_examples_after( let step_progress = progress.start(Step::PullExamples, &step_progress_name); step_progress.set_substatus("querying"); + let min_minor_str = min_capture_version.map(|version| version.minor.to_string()); + let min_patch_str = min_capture_version.map(|version| version.patch.to_string()); + let min_minor_str_ref = min_minor_str.as_deref(); + let min_patch_str_ref = min_patch_str.as_deref(); + let statement = indoc! {r#" SELECT req.event_properties:request_id::string AS request_id, @@ -509,95 +536,30 @@ pub async fn fetch_requested_examples_after( OFFSET ? "#}; - let min_minor_str = min_capture_version.map(|v| v.minor.to_string()); - let min_patch_str = min_capture_version.map(|v| v.patch.to_string()); - let min_minor_str_ref = min_minor_str.as_deref(); - let min_patch_str_ref = min_patch_str.as_deref(); - let request = json!({ - "statement": statement, - "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": role, - "bindings": { - "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "2": { "type": "TEXT", "value": after_date }, - "3": { "type": "FIXED", "value": min_minor_str_ref }, - "4": { "type": "FIXED", "value": min_minor_str_ref }, - "5": { "type": "FIXED", "value": min_minor_str_ref }, - "6": { "type": "FIXED", "value": min_patch_str_ref }, - "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "8": { "type": "FIXED", "value": offset.to_string() } - } + let bindings = json!({ + "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, + "2": { "type": "TEXT", "value": after_date }, + "3": { "type": "FIXED", "value": min_minor_str_ref }, + "4": { "type": "FIXED", "value": min_minor_str_ref }, + "5": { "type": "FIXED", "value": min_minor_str_ref }, + "6": { "type": "FIXED", "value": min_patch_str_ref }, + "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, + "8": { "type": "FIXED", "value": offset.to_string() } }); - let response = run_sql_with_polling( + let examples = fetch_examples_with_query( http_client.clone(), - &base_url, - &token, - &request, &step_progress, background_executor.clone(), + statement, + bindings, + DEFAULT_STATEMENT_TIMEOUT_SECONDS, + &["request_id", "device_id", "time", "input", "zed_version"], + requested_examples_from_response, ) .await?; - let total_rows = response - .result_set_meta_data - .as_ref() - .and_then(|m| m.num_rows) - .unwrap_or(response.data.len() as i64); - - let num_partitions = response - .result_set_meta_data - .as_ref() - .map(|m| m.partition_info.len()) - .unwrap_or(1) - .max(1); - - step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); - step_progress.set_substatus("parsing"); - - let column_indices = get_column_indices( - &response.result_set_meta_data, - &["request_id", "device_id", "time", "input", "zed_version"], - ); - - all_examples.extend(requested_examples_from_response( - &response, - &column_indices, - )?); - - if num_partitions > 1 { - let statement_handle = response - .statement_handle - .as_ref() - .context("response has multiple partitions but no statementHandle")?; - - for partition in 1..num_partitions { - step_progress.set_substatus(format!( - "fetching partition {}/{}", - partition + 1, - num_partitions - )); - - let partition_response = fetch_partition( - http_client.clone(), - &base_url, - &token, - statement_handle, - partition, - ) - .await?; - - all_examples.extend(requested_examples_from_response( - &partition_response, - &column_indices, - )?); - } - } - - step_progress.set_substatus("done"); + all_examples.extend(examples); } Ok(all_examples) @@ -617,13 +579,6 @@ pub async fn fetch_settled_examples_after( let progress = Progress::global(); - let token = std::env::var("EP_SNOWFLAKE_API_KEY") - .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; - let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( - "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", - )?; - let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); - let mut all_examples = Vec::new(); for after_date in after_timestamps.iter() { @@ -631,6 +586,8 @@ pub async fn fetch_settled_examples_after( let step_progress = progress.start(Step::PullExamples, &step_progress_name); step_progress.set_substatus("querying"); + let _ = min_capture_version; + let statement = indoc! {r#" WITH requested AS ( SELECT @@ -666,51 +623,21 @@ pub async fn fetch_settled_examples_after( OFFSET ? "#}; - let _ = min_capture_version; - let request = json!({ - "statement": statement, - "timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": role, - "bindings": { - "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, - "2": { "type": "TEXT", "value": after_date }, - "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT }, - "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, - "5": { "type": "FIXED", "value": offset.to_string() } - } + let bindings = json!({ + "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT }, + "2": { "type": "TEXT", "value": after_date }, + "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT }, + "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }, + "5": { "type": "FIXED", "value": offset.to_string() } }); - let response = run_sql_with_polling( + let examples = fetch_examples_with_query( http_client.clone(), - &base_url, - &token, - &request, &step_progress, background_executor.clone(), - ) - .await?; - - let total_rows = response - .result_set_meta_data - .as_ref() - .and_then(|m| m.num_rows) - .unwrap_or(response.data.len() as i64); - - let num_partitions = response - .result_set_meta_data - .as_ref() - .map(|m| m.partition_info.len()) - .unwrap_or(1) - .max(1); - - step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); - step_progress.set_substatus("parsing"); - - let column_indices = get_column_indices( - &response.result_set_meta_data, + statement, + bindings, + SETTLED_STATEMENT_TIMEOUT_SECONDS, &[ "request_id", "device_id", @@ -721,40 +648,11 @@ pub async fn fetch_settled_examples_after( "requested_format", "zed_version", ], - ); - - all_examples.extend(settled_examples_from_response(&response, &column_indices)?); - - if num_partitions > 1 { - let statement_handle = response - .statement_handle - .as_ref() - .context("response has multiple partitions but no statementHandle")?; - - for partition in 1..num_partitions { - step_progress.set_substatus(format!( - "fetching partition {}/{}", - partition + 1, - num_partitions - )); - - let partition_response = fetch_partition( - http_client.clone(), - &base_url, - &token, - statement_handle, - partition, - ) - .await?; - - all_examples.extend(settled_examples_from_response( - &partition_response, - &column_indices, - )?); - } - } + settled_examples_from_response, + ) + .await?; - step_progress.set_substatus("done"); + all_examples.extend(examples); } Ok(all_examples) @@ -774,13 +672,6 @@ pub async fn fetch_rated_examples_after( let progress = Progress::global(); - let token = std::env::var("EP_SNOWFLAKE_API_KEY") - .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?; - let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context( - "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://.snowflakecomputing.com)", - )?; - let role = std::env::var("EP_SNOWFLAKE_ROLE").ok(); - let mut all_examples = Vec::new(); for (after_date, rating_filter) in inputs.iter() { @@ -793,7 +684,7 @@ pub async fn fetch_rated_examples_after( let step_progress = progress.start(Step::PullExamples, &step_progress_name); step_progress.set_substatus("querying"); - let rating_value = rating_filter.as_ref().map(|r| match r { + let rating_value = rating_filter.as_ref().map(|rating| match rating { EditPredictionRating::Positive => "Positive", EditPredictionRating::Negative => "Negative", }); @@ -841,44 +732,13 @@ pub async fn fetch_rated_examples_after( "8": { "type": "FIXED", "value": offset.to_string() } }); - let request = json!({ - "statement": statement, - "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS, - "database": "EVENTS", - "schema": "PUBLIC", - "warehouse": "DBT", - "role": role, - "bindings": bindings - }); - - let response = run_sql_with_polling( + let examples = fetch_examples_with_query( http_client.clone(), - &base_url, - &token, - &request, &step_progress, background_executor.clone(), - ) - .await?; - - let total_rows = response - .result_set_meta_data - .as_ref() - .and_then(|m| m.num_rows) - .unwrap_or(response.data.len() as i64); - - let num_partitions = response - .result_set_meta_data - .as_ref() - .map(|m| m.partition_info.len()) - .unwrap_or(1) - .max(1); - - step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal); - step_progress.set_substatus("parsing"); - - let column_indices = get_column_indices( - &response.result_set_meta_data, + statement, + bindings, + DEFAULT_STATEMENT_TIMEOUT_SECONDS, &[ "request_id", "inputs", @@ -891,40 +751,11 @@ pub async fn fetch_rated_examples_after( "environment", "zed_version", ], - ); - - all_examples.extend(rated_examples_from_response(&response, &column_indices)?); - - if num_partitions > 1 { - let statement_handle = response - .statement_handle - .as_ref() - .context("response has multiple partitions but no statementHandle")?; - - for partition in 1..num_partitions { - step_progress.set_substatus(format!( - "fetching partition {}/{}", - partition + 1, - num_partitions - )); - - let partition_response = fetch_partition( - http_client.clone(), - &base_url, - &token, - statement_handle, - partition, - ) - .await?; - - all_examples.extend(rated_examples_from_response( - &partition_response, - &column_indices, - )?); - } - } + rated_examples_from_response, + ) + .await?; - step_progress.set_substatus("done"); + all_examples.extend(examples); } Ok(all_examples) @@ -933,7 +764,7 @@ pub async fn fetch_rated_examples_after( fn rated_examples_from_response<'a>( response: &'a SnowflakeStatementResponse, column_indices: &'a std::collections::HashMap, -) -> Result + 'a> { +) -> Result + 'a>> { if let Some(code) = &response.code { if code != SNOWFLAKE_SUCCESS_CODE { anyhow::bail!( @@ -1021,7 +852,7 @@ fn rated_examples_from_response<'a>( } }); - Ok(iter) + Ok(Box::new(iter)) } fn build_rated_example( @@ -1081,7 +912,7 @@ fn build_rated_example( fn requested_examples_from_response<'a>( response: &'a SnowflakeStatementResponse, column_indices: &'a std::collections::HashMap, -) -> Result + 'a> { +) -> Result + 'a>> { if let Some(code) = &response.code { if code != SNOWFLAKE_SUCCESS_CODE { anyhow::bail!( @@ -1150,13 +981,13 @@ fn requested_examples_from_response<'a>( } }); - Ok(iter) + Ok(Box::new(iter)) } fn settled_examples_from_response<'a>( response: &'a SnowflakeStatementResponse, column_indices: &'a std::collections::HashMap, -) -> Result + 'a> { +) -> Result + 'a>> { if let Some(code) = &response.code { if code != SNOWFLAKE_SUCCESS_CODE { anyhow::bail!( @@ -1271,7 +1102,7 @@ fn settled_examples_from_response<'a>( } }); - Ok(iter) + Ok(Box::new(iter)) } fn build_settled_example( @@ -1336,7 +1167,7 @@ fn build_settled_example( fn rejected_examples_from_response<'a>( response: &'a SnowflakeStatementResponse, column_indices: &'a std::collections::HashMap, -) -> Result + 'a> { +) -> Result + 'a>> { if let Some(code) = &response.code { if code != SNOWFLAKE_SUCCESS_CODE { anyhow::bail!( @@ -1421,7 +1252,7 @@ fn rejected_examples_from_response<'a>( } }); - Ok(iter) + Ok(Box::new(iter)) } fn build_rejected_example(