@@ -152,6 +152,103 @@ async fn run_sql_with_polling(
Ok(response)
}
+struct SnowflakeConfig {
+ token: String,
+ base_url: String,
+ role: Option<String>,
+}
+
+async fn fetch_examples_with_query(
+ http_client: Arc<dyn HttpClient>,
+ step_progress: &crate::progress::StepProgress,
+ background_executor: BackgroundExecutor,
+ statement: &str,
+ bindings: JsonValue,
+ timeout_seconds: u64,
+ required_columns: &[&str],
+ parse_response: for<'a> fn(
+ &'a SnowflakeStatementResponse,
+ &'a std::collections::HashMap<String, usize>,
+ ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
+) -> Result<Vec<Example>> {
+ let snowflake = SnowflakeConfig {
+ token: std::env::var("EP_SNOWFLAKE_API_KEY")
+ .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
+ base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+ "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+ )?,
+ role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
+ };
+ let request = json!({
+ "statement": statement,
+ "timeout": timeout_seconds,
+ "database": "EVENTS",
+ "schema": "PUBLIC",
+ "warehouse": "DBT",
+ "role": snowflake.role.as_deref(),
+ "bindings": bindings
+ });
+
+ let response = run_sql_with_polling(
+ http_client.clone(),
+ &snowflake.base_url,
+ &snowflake.token,
+ &request,
+ step_progress,
+ background_executor,
+ )
+ .await?;
+
+ let total_rows = response
+ .result_set_meta_data
+ .as_ref()
+ .and_then(|meta| meta.num_rows)
+ .unwrap_or(response.data.len() as i64);
+ let partition_count = response
+ .result_set_meta_data
+ .as_ref()
+ .map(|meta| meta.partition_info.len())
+ .unwrap_or(1)
+ .max(1);
+
+ step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+ step_progress.set_substatus("parsing");
+
+ let column_indices = get_column_indices(&response.result_set_meta_data, required_columns);
+
+ let mut parsed_examples = Vec::with_capacity(total_rows.max(0) as usize);
+ parsed_examples.extend(parse_response(&response, &column_indices)?);
+
+ if partition_count > 1 {
+ let statement_handle = response
+ .statement_handle
+ .as_ref()
+ .context("response has multiple partitions but no statementHandle")?;
+
+ for partition in 1..partition_count {
+ step_progress.set_substatus(format!(
+ "fetching partition {}/{}",
+ partition + 1,
+ partition_count
+ ));
+
+ let partition_response = fetch_partition(
+ http_client.clone(),
+ &snowflake.base_url,
+ &snowflake.token,
+ statement_handle,
+ partition,
+ )
+ .await?;
+
+ parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
+ }
+ }
+
+ step_progress.set_substatus("done");
+ Ok(parsed_examples)
+}
+
pub(crate) async fn fetch_partition(
http_client: Arc<dyn HttpClient>,
base_url: &str,
@@ -305,13 +402,6 @@ pub async fn fetch_rejected_examples_after(
let progress = Progress::global();
- let token = std::env::var("EP_SNOWFLAKE_API_KEY")
- .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
- let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
- "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
- )?;
- let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
-
let mut all_examples = Vec::new();
for after_date in after_timestamps.iter() {
@@ -319,10 +409,11 @@ pub async fn fetch_rejected_examples_after(
let step_progress = progress.start(Step::PullExamples, &step_progress_name);
step_progress.set_substatus("querying");
- // Join rejected events with their corresponding request events to get the full context.
- // We filter for V3 sampling data which contains the structured input we need.
- // We also filter for predictions that were actually shown to the user (was_shown = true)
- // to focus on explicit user rejections rather than implicit cancellations.
+ let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
+ let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
+ let min_minor_str_ref = min_minor_str.as_deref();
+ let min_patch_str_ref = min_patch_str.as_deref();
+
let statement = indoc! {r#"
SELECT
req.event_properties:request_id::string AS request_id,
@@ -355,58 +446,25 @@ pub async fn fetch_rejected_examples_after(
OFFSET ?
"#};
- let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
- let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
- let min_minor_str_ref = min_minor_str.as_deref();
- let min_patch_str_ref = min_patch_str.as_deref();
- let request = json!({
- "statement": statement,
- "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
- "database": "EVENTS",
- "schema": "PUBLIC",
- "warehouse": "DBT",
- "role": role,
- "bindings": {
- "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
- "3": { "type": "TEXT", "value": after_date },
- "4": { "type": "FIXED", "value": min_minor_str_ref },
- "5": { "type": "FIXED", "value": min_minor_str_ref },
- "6": { "type": "FIXED", "value": min_minor_str_ref },
- "7": { "type": "FIXED", "value": min_patch_str_ref },
- "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "9": { "type": "FIXED", "value": offset.to_string() }
- }
+ let bindings = json!({
+ "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+ "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
+ "3": { "type": "TEXT", "value": after_date },
+ "4": { "type": "FIXED", "value": min_minor_str_ref },
+ "5": { "type": "FIXED", "value": min_minor_str_ref },
+ "6": { "type": "FIXED", "value": min_minor_str_ref },
+ "7": { "type": "FIXED", "value": min_patch_str_ref },
+ "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
+ "9": { "type": "FIXED", "value": offset.to_string() }
});
- let response = run_sql_with_polling(
+ let examples = fetch_examples_with_query(
http_client.clone(),
- &base_url,
- &token,
- &request,
&step_progress,
background_executor.clone(),
- )
- .await?;
-
- let total_rows = response
- .result_set_meta_data
- .as_ref()
- .and_then(|m| m.num_rows)
- .unwrap_or(response.data.len() as i64);
-
- let num_partitions = response
- .result_set_meta_data
- .as_ref()
- .map(|m| m.partition_info.len())
- .unwrap_or(1)
- .max(1);
-
- step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
- step_progress.set_substatus("parsing");
-
- let column_indices = get_column_indices(
- &response.result_set_meta_data,
+ statement,
+ bindings,
+ DEFAULT_STATEMENT_TIMEOUT_SECONDS,
&[
"request_id",
"device_id",
@@ -418,40 +476,11 @@ pub async fn fetch_rejected_examples_after(
"reason",
"zed_version",
],
- );
-
- all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
-
- if num_partitions > 1 {
- let statement_handle = response
- .statement_handle
- .as_ref()
- .context("response has multiple partitions but no statementHandle")?;
-
- for partition in 1..num_partitions {
- step_progress.set_substatus(format!(
- "fetching partition {}/{}",
- partition + 1,
- num_partitions
- ));
-
- let partition_response = fetch_partition(
- http_client.clone(),
- &base_url,
- &token,
- statement_handle,
- partition,
- )
- .await?;
-
- all_examples.extend(rejected_examples_from_response(
- &partition_response,
- &column_indices,
- )?);
- }
- }
+ rejected_examples_from_response,
+ )
+ .await?;
- step_progress.set_substatus("done");
+ all_examples.extend(examples);
}
Ok(all_examples)
@@ -471,13 +500,6 @@ pub async fn fetch_requested_examples_after(
let progress = Progress::global();
- let token = std::env::var("EP_SNOWFLAKE_API_KEY")
- .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
- let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
- "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
- )?;
- let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
-
let mut all_examples = Vec::new();
for after_date in after_timestamps.iter() {
@@ -485,6 +507,11 @@ pub async fn fetch_requested_examples_after(
let step_progress = progress.start(Step::PullExamples, &step_progress_name);
step_progress.set_substatus("querying");
+ let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
+ let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
+ let min_minor_str_ref = min_minor_str.as_deref();
+ let min_patch_str_ref = min_patch_str.as_deref();
+
let statement = indoc! {r#"
SELECT
req.event_properties:request_id::string AS request_id,
@@ -509,95 +536,30 @@ pub async fn fetch_requested_examples_after(
OFFSET ?
"#};
- let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
- let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
- let min_minor_str_ref = min_minor_str.as_deref();
- let min_patch_str_ref = min_patch_str.as_deref();
- let request = json!({
- "statement": statement,
- "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
- "database": "EVENTS",
- "schema": "PUBLIC",
- "warehouse": "DBT",
- "role": role,
- "bindings": {
- "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "2": { "type": "TEXT", "value": after_date },
- "3": { "type": "FIXED", "value": min_minor_str_ref },
- "4": { "type": "FIXED", "value": min_minor_str_ref },
- "5": { "type": "FIXED", "value": min_minor_str_ref },
- "6": { "type": "FIXED", "value": min_patch_str_ref },
- "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "8": { "type": "FIXED", "value": offset.to_string() }
- }
+ let bindings = json!({
+ "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+ "2": { "type": "TEXT", "value": after_date },
+ "3": { "type": "FIXED", "value": min_minor_str_ref },
+ "4": { "type": "FIXED", "value": min_minor_str_ref },
+ "5": { "type": "FIXED", "value": min_minor_str_ref },
+ "6": { "type": "FIXED", "value": min_patch_str_ref },
+ "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
+ "8": { "type": "FIXED", "value": offset.to_string() }
});
- let response = run_sql_with_polling(
+ let examples = fetch_examples_with_query(
http_client.clone(),
- &base_url,
- &token,
- &request,
&step_progress,
background_executor.clone(),
+ statement,
+ bindings,
+ DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ &["request_id", "device_id", "time", "input", "zed_version"],
+ requested_examples_from_response,
)
.await?;
- let total_rows = response
- .result_set_meta_data
- .as_ref()
- .and_then(|m| m.num_rows)
- .unwrap_or(response.data.len() as i64);
-
- let num_partitions = response
- .result_set_meta_data
- .as_ref()
- .map(|m| m.partition_info.len())
- .unwrap_or(1)
- .max(1);
-
- step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
- step_progress.set_substatus("parsing");
-
- let column_indices = get_column_indices(
- &response.result_set_meta_data,
- &["request_id", "device_id", "time", "input", "zed_version"],
- );
-
- all_examples.extend(requested_examples_from_response(
- &response,
- &column_indices,
- )?);
-
- if num_partitions > 1 {
- let statement_handle = response
- .statement_handle
- .as_ref()
- .context("response has multiple partitions but no statementHandle")?;
-
- for partition in 1..num_partitions {
- step_progress.set_substatus(format!(
- "fetching partition {}/{}",
- partition + 1,
- num_partitions
- ));
-
- let partition_response = fetch_partition(
- http_client.clone(),
- &base_url,
- &token,
- statement_handle,
- partition,
- )
- .await?;
-
- all_examples.extend(requested_examples_from_response(
- &partition_response,
- &column_indices,
- )?);
- }
- }
-
- step_progress.set_substatus("done");
+ all_examples.extend(examples);
}
Ok(all_examples)
@@ -617,13 +579,6 @@ pub async fn fetch_settled_examples_after(
let progress = Progress::global();
- let token = std::env::var("EP_SNOWFLAKE_API_KEY")
- .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
- let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
- "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
- )?;
- let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
-
let mut all_examples = Vec::new();
for after_date in after_timestamps.iter() {
@@ -631,6 +586,8 @@ pub async fn fetch_settled_examples_after(
let step_progress = progress.start(Step::PullExamples, &step_progress_name);
step_progress.set_substatus("querying");
+ let _ = min_capture_version;
+
let statement = indoc! {r#"
WITH requested AS (
SELECT
@@ -666,51 +623,21 @@ pub async fn fetch_settled_examples_after(
OFFSET ?
"#};
- let _ = min_capture_version;
- let request = json!({
- "statement": statement,
- "timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS,
- "database": "EVENTS",
- "schema": "PUBLIC",
- "warehouse": "DBT",
- "role": role,
- "bindings": {
- "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "2": { "type": "TEXT", "value": after_date },
- "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
- "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "5": { "type": "FIXED", "value": offset.to_string() }
- }
+ let bindings = json!({
+ "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+ "2": { "type": "TEXT", "value": after_date },
+ "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
+ "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
+ "5": { "type": "FIXED", "value": offset.to_string() }
});
- let response = run_sql_with_polling(
+ let examples = fetch_examples_with_query(
http_client.clone(),
- &base_url,
- &token,
- &request,
&step_progress,
background_executor.clone(),
- )
- .await?;
-
- let total_rows = response
- .result_set_meta_data
- .as_ref()
- .and_then(|m| m.num_rows)
- .unwrap_or(response.data.len() as i64);
-
- let num_partitions = response
- .result_set_meta_data
- .as_ref()
- .map(|m| m.partition_info.len())
- .unwrap_or(1)
- .max(1);
-
- step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
- step_progress.set_substatus("parsing");
-
- let column_indices = get_column_indices(
- &response.result_set_meta_data,
+ statement,
+ bindings,
+ SETTLED_STATEMENT_TIMEOUT_SECONDS,
&[
"request_id",
"device_id",
@@ -721,40 +648,11 @@ pub async fn fetch_settled_examples_after(
"requested_format",
"zed_version",
],
- );
-
- all_examples.extend(settled_examples_from_response(&response, &column_indices)?);
-
- if num_partitions > 1 {
- let statement_handle = response
- .statement_handle
- .as_ref()
- .context("response has multiple partitions but no statementHandle")?;
-
- for partition in 1..num_partitions {
- step_progress.set_substatus(format!(
- "fetching partition {}/{}",
- partition + 1,
- num_partitions
- ));
-
- let partition_response = fetch_partition(
- http_client.clone(),
- &base_url,
- &token,
- statement_handle,
- partition,
- )
- .await?;
-
- all_examples.extend(settled_examples_from_response(
- &partition_response,
- &column_indices,
- )?);
- }
- }
+ settled_examples_from_response,
+ )
+ .await?;
- step_progress.set_substatus("done");
+ all_examples.extend(examples);
}
Ok(all_examples)
@@ -774,13 +672,6 @@ pub async fn fetch_rated_examples_after(
let progress = Progress::global();
- let token = std::env::var("EP_SNOWFLAKE_API_KEY")
- .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
- let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
- "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
- )?;
- let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
-
let mut all_examples = Vec::new();
for (after_date, rating_filter) in inputs.iter() {
@@ -793,7 +684,7 @@ pub async fn fetch_rated_examples_after(
let step_progress = progress.start(Step::PullExamples, &step_progress_name);
step_progress.set_substatus("querying");
- let rating_value = rating_filter.as_ref().map(|r| match r {
+ let rating_value = rating_filter.as_ref().map(|rating| match rating {
EditPredictionRating::Positive => "Positive",
EditPredictionRating::Negative => "Negative",
});
@@ -841,44 +732,13 @@ pub async fn fetch_rated_examples_after(
"8": { "type": "FIXED", "value": offset.to_string() }
});
- let request = json!({
- "statement": statement,
- "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
- "database": "EVENTS",
- "schema": "PUBLIC",
- "warehouse": "DBT",
- "role": role,
- "bindings": bindings
- });
-
- let response = run_sql_with_polling(
+ let examples = fetch_examples_with_query(
http_client.clone(),
- &base_url,
- &token,
- &request,
&step_progress,
background_executor.clone(),
- )
- .await?;
-
- let total_rows = response
- .result_set_meta_data
- .as_ref()
- .and_then(|m| m.num_rows)
- .unwrap_or(response.data.len() as i64);
-
- let num_partitions = response
- .result_set_meta_data
- .as_ref()
- .map(|m| m.partition_info.len())
- .unwrap_or(1)
- .max(1);
-
- step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
- step_progress.set_substatus("parsing");
-
- let column_indices = get_column_indices(
- &response.result_set_meta_data,
+ statement,
+ bindings,
+ DEFAULT_STATEMENT_TIMEOUT_SECONDS,
&[
"request_id",
"inputs",
@@ -891,40 +751,11 @@ pub async fn fetch_rated_examples_after(
"environment",
"zed_version",
],
- );
-
- all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
-
- if num_partitions > 1 {
- let statement_handle = response
- .statement_handle
- .as_ref()
- .context("response has multiple partitions but no statementHandle")?;
-
- for partition in 1..num_partitions {
- step_progress.set_substatus(format!(
- "fetching partition {}/{}",
- partition + 1,
- num_partitions
- ));
-
- let partition_response = fetch_partition(
- http_client.clone(),
- &base_url,
- &token,
- statement_handle,
- partition,
- )
- .await?;
-
- all_examples.extend(rated_examples_from_response(
- &partition_response,
- &column_indices,
- )?);
- }
- }
+ rated_examples_from_response,
+ )
+ .await?;
- step_progress.set_substatus("done");
+ all_examples.extend(examples);
}
Ok(all_examples)
@@ -933,7 +764,7 @@ pub async fn fetch_rated_examples_after(
fn rated_examples_from_response<'a>(
response: &'a SnowflakeStatementResponse,
column_indices: &'a std::collections::HashMap<String, usize>,
-) -> Result<impl Iterator<Item = Example> + 'a> {
+) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
@@ -1021,7 +852,7 @@ fn rated_examples_from_response<'a>(
}
});
- Ok(iter)
+ Ok(Box::new(iter))
}
fn build_rated_example(
@@ -1081,7 +912,7 @@ fn build_rated_example(
fn requested_examples_from_response<'a>(
response: &'a SnowflakeStatementResponse,
column_indices: &'a std::collections::HashMap<String, usize>,
-) -> Result<impl Iterator<Item = Example> + 'a> {
+) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
@@ -1150,13 +981,13 @@ fn requested_examples_from_response<'a>(
}
});
- Ok(iter)
+ Ok(Box::new(iter))
}
fn settled_examples_from_response<'a>(
response: &'a SnowflakeStatementResponse,
column_indices: &'a std::collections::HashMap<String, usize>,
-) -> Result<impl Iterator<Item = Example> + 'a> {
+) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
@@ -1271,7 +1102,7 @@ fn settled_examples_from_response<'a>(
}
});
- Ok(iter)
+ Ok(Box::new(iter))
}
fn build_settled_example(
@@ -1336,7 +1167,7 @@ fn build_settled_example(
fn rejected_examples_from_response<'a>(
response: &'a SnowflakeStatementResponse,
column_indices: &'a std::collections::HashMap<String, usize>,
-) -> Result<impl Iterator<Item = Example> + 'a> {
+) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
if let Some(code) = &response.code {
if code != SNOWFLAKE_SUCCESS_CODE {
anyhow::bail!(
@@ -1421,7 +1252,7 @@ fn rejected_examples_from_response<'a>(
}
});
- Ok(iter)
+ Ok(Box::new(iter))
}
fn build_rejected_example(