@@ -5,6 +5,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request};
use indoc::indoc;
use serde::Deserialize;
use serde_json::{Value as JsonValue, json};
+use std::collections::HashMap;
use std::fmt::Write as _;
use std::io::Read;
use std::sync::Arc;
@@ -13,17 +14,14 @@ use telemetry_events::EditPredictionRating;
use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
-use crate::example::Example;
+use crate::PredictionProvider;
+use crate::example::{Example, ExamplePrompt};
use crate::progress::{InfoStyle, Progress, Step};
-const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
-const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
-const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
-const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
-const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
+const SNOWFLAKE_TIMEOUT_CODE: &str = "000630";
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
@@ -34,10 +32,13 @@ pub struct MinCaptureVersion {
pub patch: u32,
}
-const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
-const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
-pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
// Maximum number of times a failed partition fetch is retried before the
// error is propagated to the caller.
const PARTITION_FETCH_MAX_RETRIES: usize = 3;
// Backoff delay before each retry, indexed by retry attempt; the array length
// is tied to PARTITION_FETCH_MAX_RETRIES so indexing can never go out of
// bounds.
const PARTITION_FETCH_RETRY_DELAYS: [Duration; PARTITION_FETCH_MAX_RETRIES] = [
    Duration::from_millis(500),
    Duration::from_secs(1),
    Duration::from_secs(2),
];
/// Parse an input token of the form `captured-after:{timestamp}`.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
@@ -127,26 +128,25 @@ async fn run_sql_with_polling(
.context("async query response missing statementHandle")?
.clone();
- for attempt in 1..=MAX_POLL_ATTEMPTS {
+ for attempt in 0.. {
step_progress.set_substatus(format!("polling ({attempt})"));
background_executor.timer(POLL_INTERVAL).await;
- response =
- fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
+ response = fetch_partition_with_retries(
+ http_client.clone(),
+ base_url,
+ token,
+ &statement_handle,
+ 0,
+ background_executor.clone(),
+ )
+ .await?;
if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
break;
}
}
-
- if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
- anyhow::bail!(
- "query still running after {} poll attempts ({} seconds)",
- MAX_POLL_ATTEMPTS,
- MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
- );
- }
}
Ok(response)
@@ -158,19 +158,29 @@ struct SnowflakeConfig {
role: Option<String>,
}
-async fn fetch_examples_with_query(
/// Cursor state threaded through retried Snowflake queries so that an attempt
/// which times out partway through can resume where the previous one stopped.
#[derive(Clone)]
struct QueryRetryState {
    /// Exclusive lower bound timestamp for the next attempt (bound into the
    /// SQL as `requested_at > TRY_TO_TIMESTAMP_NTZ(?)`).
    resume_after: String,
    /// Rows still wanted across all attempts; `None` means no limit.
    remaining_limit: Option<usize>,
    /// SQL OFFSET for the next attempt; reset to 0 once `resume_after`
    /// advances past already-fetched rows.
    offset: usize,
}
+
+async fn fetch_examples_with_query<MakeBindings>(
http_client: Arc<dyn HttpClient>,
step_progress: &crate::progress::StepProgress,
background_executor: BackgroundExecutor,
statement: &str,
- bindings: JsonValue,
- timeout_seconds: u64,
+ initial_retry_state: QueryRetryState,
+ make_bindings: MakeBindings,
required_columns: &[&str],
parse_response: for<'a> fn(
&'a SnowflakeStatementResponse,
- &'a std::collections::HashMap<String, usize>,
+ &'a HashMap<String, usize>,
) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
-) -> Result<Vec<Example>> {
+) -> Result<Vec<Example>>
+where
+ MakeBindings: Fn(&QueryRetryState) -> JsonValue,
+{
let snowflake = SnowflakeConfig {
token: std::env::var("EP_SNOWFLAKE_API_KEY")
.context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
@@ -179,74 +189,153 @@ async fn fetch_examples_with_query(
)?,
role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
};
- let request = json!({
- "statement": statement,
- "timeout": timeout_seconds,
- "database": "EVENTS",
- "schema": "PUBLIC",
- "warehouse": "DBT",
- "role": snowflake.role.as_deref(),
- "bindings": bindings
- });
- let response = run_sql_with_polling(
- http_client.clone(),
- &snowflake.base_url,
- &snowflake.token,
- &request,
- step_progress,
- background_executor,
- )
- .await?;
-
- let total_rows = response
- .result_set_meta_data
- .as_ref()
- .and_then(|meta| meta.num_rows)
- .unwrap_or(response.data.len() as i64);
- let partition_count = response
- .result_set_meta_data
- .as_ref()
- .map(|meta| meta.partition_info.len())
- .unwrap_or(1)
- .max(1);
-
- step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
- step_progress.set_substatus("parsing");
-
- let column_indices = get_column_indices(&response.result_set_meta_data, required_columns);
-
- let mut parsed_examples = Vec::with_capacity(total_rows as usize);
- parsed_examples.extend(parse_response(&response, &column_indices)?);
-
- if partition_count > 1 {
- let statement_handle = response
- .statement_handle
+ let mut requested_columns = required_columns.to_vec();
+ if !requested_columns.contains(&"continuation_time") {
+ requested_columns.push("continuation_time");
+ }
+
+ let mut parsed_examples = Vec::new();
+ let mut retry_state = initial_retry_state;
+ let mut retry_count = 0usize;
+
+ loop {
+ let bindings = make_bindings(&retry_state);
+ let request = json!({
+ "statement": statement,
+ "database": "EVENTS",
+ "schema": "PUBLIC",
+ "warehouse": "DBT",
+ "role": snowflake.role.as_deref(),
+ "bindings": bindings
+ });
+
+ let response = match run_sql_with_polling(
+ http_client.clone(),
+ &snowflake.base_url,
+ &snowflake.token,
+ &request,
+ step_progress,
+ background_executor.clone(),
+ )
+ .await
+ {
+ Ok(response) => response,
+ Err(error) => {
+ if is_snowflake_timeout_error(&error) && !parsed_examples.is_empty() {
+ retry_count += 1;
+ step_progress.set_substatus(format!(
+ "retrying from {} ({retry_count})",
+ retry_state.resume_after
+ ));
+ continue;
+ }
+
+ return Err(error);
+ }
+ };
+
+ let total_rows = response
+ .result_set_meta_data
+ .as_ref()
+ .and_then(|meta| meta.num_rows)
+ .unwrap_or(response.data.len() as i64);
+ let partition_count = response
+ .result_set_meta_data
.as_ref()
- .context("response has multiple partitions but no statementHandle")?;
+ .map(|meta| meta.partition_info.len())
+ .unwrap_or(1)
+ .max(1);
- for partition in 1..partition_count {
- step_progress.set_substatus(format!(
- "fetching partition {}/{}",
- partition + 1,
- partition_count
- ));
+ step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+ step_progress.set_substatus("parsing");
- let partition_response = fetch_partition(
- http_client.clone(),
- &snowflake.base_url,
- &snowflake.token,
- statement_handle,
- partition,
- )
- .await?;
+ let column_indices = get_column_indices(&response.result_set_meta_data, &requested_columns);
+ let mut rows_fetched_this_attempt = 0usize;
+ let mut timed_out_fetching_partition = false;
+
+ parsed_examples.extend(parse_response(&response, &column_indices)?);
+ rows_fetched_this_attempt += response.data.len();
+ let mut last_continuation_time_this_attempt =
+ last_continuation_timestamp_from_response(&response, &column_indices);
- parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
+ if partition_count > 1 {
+ let statement_handle = response
+ .statement_handle
+ .as_ref()
+ .context("response has multiple partitions but no statementHandle")?;
+
+ for partition in 1..partition_count {
+ step_progress.set_substatus(format!(
+ "fetching partition {}/{}",
+ partition + 1,
+ partition_count
+ ));
+
+ let partition_response = match fetch_partition_with_retries(
+ http_client.clone(),
+ &snowflake.base_url,
+ &snowflake.token,
+ statement_handle,
+ partition,
+ background_executor.clone(),
+ )
+ .await
+ {
+ Ok(response) => response,
+ Err(error) => {
+ if is_snowflake_timeout_error(&error) && rows_fetched_this_attempt > 0 {
+ timed_out_fetching_partition = true;
+ break;
+ }
+
+ return Err(error);
+ }
+ };
+
+ parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
+ rows_fetched_this_attempt += partition_response.data.len();
+
+ if let Some(partition_continuation_time) =
+ last_continuation_timestamp_from_response(&partition_response, &column_indices)
+ {
+ last_continuation_time_this_attempt = Some(partition_continuation_time);
+ }
+ }
}
- }
- step_progress.set_substatus("done");
- Ok(parsed_examples)
+ if rows_fetched_this_attempt == 0 {
+ step_progress.set_substatus("done");
+ return Ok(parsed_examples);
+ }
+
+ if let Some(remaining_limit_value) = &mut retry_state.remaining_limit {
+ *remaining_limit_value =
+ remaining_limit_value.saturating_sub(rows_fetched_this_attempt);
+ if *remaining_limit_value == 0 {
+ step_progress.set_substatus("done");
+ return Ok(parsed_examples);
+ }
+ }
+
+ if !timed_out_fetching_partition {
+ step_progress.set_substatus("done");
+ return Ok(parsed_examples);
+ }
+
+ let Some(last_continuation_time_this_attempt) = last_continuation_time_this_attempt else {
+ step_progress.set_substatus("done");
+ return Ok(parsed_examples);
+ };
+
+ retry_state.resume_after = last_continuation_time_this_attempt;
+ retry_state.offset = 0;
+ retry_count += 1;
+ step_progress.set_substatus(format!(
+ "retrying from {} ({retry_count})",
+ retry_state.resume_after
+ ));
+ }
}
pub(crate) async fn fetch_partition(
@@ -338,6 +427,57 @@ pub(crate) async fn fetch_partition(
})
}
+async fn fetch_partition_with_retries(
+ http_client: Arc<dyn HttpClient>,
+ base_url: &str,
+ token: &str,
+ statement_handle: &str,
+ partition: usize,
+ background_executor: BackgroundExecutor,
+) -> Result<SnowflakeStatementResponse> {
+ let mut last_error = None;
+
+ for retry_attempt in 0..=PARTITION_FETCH_MAX_RETRIES {
+ match fetch_partition(
+ http_client.clone(),
+ base_url,
+ token,
+ statement_handle,
+ partition,
+ )
+ .await
+ {
+ Ok(response) => return Ok(response),
+ Err(error) => {
+ if retry_attempt == PARTITION_FETCH_MAX_RETRIES
+ || !is_transient_partition_fetch_error(&error)
+ {
+ return Err(error);
+ }
+
+ last_error = Some(error);
+ background_executor
+ .timer(PARTITION_FETCH_RETRY_DELAYS[retry_attempt])
+ .await;
+ }
+ }
+ }
+
+ match last_error {
+ Some(error) => Err(error),
+ None => anyhow::bail!("partition fetch retry loop exited without a result"),
+ }
+}
+
+fn is_transient_partition_fetch_error(error: &anyhow::Error) -> bool {
+ error.chain().any(|cause| {
+ let message = cause.to_string();
+ message.contains("failed to read Snowflake SQL API partition response body")
+ || message.contains("unexpected EOF")
+ || message.contains("peer closed connection without sending TLS close_notify")
+ })
+}
+
pub(crate) async fn run_sql(
http_client: Arc<dyn HttpClient>,
base_url: &str,
@@ -379,19 +519,32 @@ pub(crate) async fn run_sql(
bytes
};
- if !status.is_success() && status.as_u16() != 202 {
+ let snowflake_response = serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
+ .context("failed to parse Snowflake SQL API response JSON")?;
+
+ if !status.is_success() && status.as_u16() != 202 && !is_timeout_response(&snowflake_response) {
let body_text = String::from_utf8_lossy(&body_bytes);
anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
}
- serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
- .context("failed to parse Snowflake SQL API response JSON")
+ if is_timeout_response(&snowflake_response) {
+ anyhow::bail!(
+ "snowflake sql api timed out code={} message={}",
+ snowflake_response.code.as_deref().unwrap_or("<no code>"),
+ snowflake_response
+ .message
+ .as_deref()
+ .unwrap_or("<no message>")
+ );
+ }
+
+ Ok(snowflake_response)
}
pub async fn fetch_rejected_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
- max_rows_per_timestamp: usize,
+ max_rows_per_timestamp: Option<usize>,
offset: usize,
background_executor: BackgroundExecutor,
min_capture_version: Option<MinCaptureVersion>,
@@ -416,55 +569,53 @@ pub async fn fetch_rejected_examples_after(
let statement = indoc! {r#"
SELECT
- req.event_properties:request_id::string AS request_id,
- req.device_id::string AS device_id,
- req.time::string AS time,
- req.event_properties:input AS input,
- req.event_properties:prompt::string AS prompt,
- req.event_properties:output::string AS output,
- rej.event_properties:was_shown::boolean AS was_shown,
- rej.event_properties:reason::string AS reason,
- req.event_properties:zed_version::string AS zed_version
- FROM events req
- INNER JOIN events rej
- ON req.event_properties:request_id = rej.event_properties:request_id
- WHERE req.event_type = ?
- AND rej.event_type = ?
- AND req.event_properties:version = 'V3'
- AND rej.event_properties:was_shown = true
- AND req.event_properties:input:can_collect_data = true
- AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+ ep_request_id AS request_id,
+ device_id AS device_id,
+ requested_at::string AS continuation_time,
+ requested_at::string AS time,
+ input_payload AS input,
+ prompt AS prompt,
+ requested_output AS output,
+ is_ep_shown_before_rejected AS was_shown,
+ ep_rejected_reason AS reason,
+ zed_version AS zed_version
+ FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+ WHERE ep_outcome LIKE 'Rejected%'
+ AND is_ep_shown_before_rejected = true
+ AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
AND (? IS NULL OR (
- TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
+ TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
OR (
- TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
- AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
+ TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
+ AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
)
))
- ORDER BY req.time ASC
+ ORDER BY requested_at ASC
LIMIT ?
OFFSET ?
"#};
- let bindings = json!({
- "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
- "3": { "type": "TEXT", "value": after_date },
- "4": { "type": "FIXED", "value": min_minor_str_ref },
- "5": { "type": "FIXED", "value": min_minor_str_ref },
- "6": { "type": "FIXED", "value": min_minor_str_ref },
- "7": { "type": "FIXED", "value": min_patch_str_ref },
- "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "9": { "type": "FIXED", "value": offset.to_string() }
- });
-
let examples = fetch_examples_with_query(
http_client.clone(),
&step_progress,
background_executor.clone(),
statement,
- bindings,
- DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ QueryRetryState {
+ resume_after: after_date.clone(),
+ remaining_limit: max_rows_per_timestamp,
+ offset,
+ },
+ |retry_state| {
+ json!({
+ "1": { "type": "TEXT", "value": retry_state.resume_after },
+ "2": { "type": "FIXED", "value": min_minor_str_ref },
+ "3": { "type": "FIXED", "value": min_minor_str_ref },
+ "4": { "type": "FIXED", "value": min_minor_str_ref },
+ "5": { "type": "FIXED", "value": min_patch_str_ref },
+ "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+ "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
+ })
+ },
&[
"request_id",
"device_id",
@@ -486,10 +637,14 @@ pub async fn fetch_rejected_examples_after(
Ok(all_examples)
}
/// Render an optional row limit as a Snowflake bind value: the number as a
/// string, or the literal string "NULL" when no limit applies.
fn format_limit(limit: Option<usize>) -> String {
    // map_or_else avoids the eager "NULL" allocation that `unwrap_or` would
    // perform even when a limit is present.
    limit.map_or_else(|| "NULL".to_string(), |l| l.to_string())
}
+
pub async fn fetch_requested_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
- max_rows_per_timestamp: usize,
+ max_rows_per_timestamp: Option<usize>,
offset: usize,
background_executor: BackgroundExecutor,
min_capture_version: Option<MinCaptureVersion>,
@@ -514,46 +669,47 @@ pub async fn fetch_requested_examples_after(
let statement = indoc! {r#"
SELECT
- req.event_properties:request_id::string AS request_id,
- req.device_id::string AS device_id,
- req.time::string AS time,
- req.event_properties:input AS input,
- req.event_properties:zed_version::string AS zed_version
- FROM events req
- WHERE req.event_type = ?
- AND req.event_properties:version = 'V3'
- AND req.event_properties:input:can_collect_data = true
- AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+ ep_request_id AS request_id,
+ device_id AS device_id,
+ requested_at::string AS continuation_time,
+ requested_at::string AS time,
+ input_payload AS input,
+ zed_version AS zed_version
+ FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+ WHERE requested_at > TRY_TO_TIMESTAMP_NTZ(?)
AND (? IS NULL OR (
- TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
+ TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
OR (
- TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
- AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
+ TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
+ AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
)
))
- ORDER BY req.time ASC
+ ORDER BY requested_at ASC
LIMIT ?
OFFSET ?
"#};
- let bindings = json!({
- "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "2": { "type": "TEXT", "value": after_date },
- "3": { "type": "FIXED", "value": min_minor_str_ref },
- "4": { "type": "FIXED", "value": min_minor_str_ref },
- "5": { "type": "FIXED", "value": min_minor_str_ref },
- "6": { "type": "FIXED", "value": min_patch_str_ref },
- "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "8": { "type": "FIXED", "value": offset.to_string() }
- });
-
let examples = fetch_examples_with_query(
http_client.clone(),
&step_progress,
background_executor.clone(),
statement,
- bindings,
- DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ QueryRetryState {
+ resume_after: after_date.clone(),
+ remaining_limit: max_rows_per_timestamp,
+ offset,
+ },
+ |retry_state| {
+ json!({
+ "1": { "type": "TEXT", "value": retry_state.resume_after },
+ "2": { "type": "FIXED", "value": min_minor_str_ref },
+ "3": { "type": "FIXED", "value": min_minor_str_ref },
+ "4": { "type": "FIXED", "value": min_minor_str_ref },
+ "5": { "type": "FIXED", "value": min_patch_str_ref },
+ "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+ "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
+ })
+ },
&["request_id", "device_id", "time", "input", "zed_version"],
requested_examples_from_response,
)
@@ -568,7 +724,7 @@ pub async fn fetch_requested_examples_after(
pub async fn fetch_captured_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
- max_rows_per_timestamp: usize,
+ max_rows_per_timestamp: Option<usize>,
offset: usize,
background_executor: BackgroundExecutor,
min_capture_version: Option<MinCaptureVersion>,
@@ -593,54 +749,51 @@ pub async fn fetch_captured_examples_after(
let statement = indoc! {r#"
SELECT
- settled.event_properties:request_id::string AS request_id,
- settled.device_id::string AS device_id,
- settled.time::string AS time,
- req.event_properties:input AS input,
- settled.event_properties:settled_editable_region::string AS settled_editable_region,
- settled.event_properties:example AS example,
- req.event_properties:zed_version::string AS zed_version
- FROM events settled
- INNER JOIN events req
- ON settled.event_properties:request_id::string = req.event_properties:request_id::string
- WHERE settled.event_type = ?
- AND req.event_type = ?
- AND req.event_properties:version = 'V3'
- AND req.event_properties:input:can_collect_data = true
- AND settled.event_properties:example IS NOT NULL
- AND TYPEOF(settled.event_properties:example) != 'NULL_VALUE'
- AND settled.time > TRY_TO_TIMESTAMP_NTZ(?)
+ ep_request_id AS request_id,
+ device_id AS device_id,
+ requested_at::string AS continuation_time,
+ requested_at::string AS time,
+ input_payload AS input,
+ settled_editable_region AS settled_editable_region,
+ example_payload AS example,
+ zed_version AS zed_version
+ FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+ WHERE settled_editable_region IS NOT NULL
+ AND example_payload IS NOT NULL
+ AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
AND (? IS NULL OR (
- TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
+ TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) > ?
OR (
- TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
- AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
+ TRY_CAST(SPLIT_PART(zed_version, '.', 2) AS INTEGER) = ?
+ AND TRY_CAST(SPLIT_PART(SPLIT_PART(zed_version, '.', 3), '+', 1) AS INTEGER) >= ?
)
))
- ORDER BY settled.time ASC
+ ORDER BY requested_at ASC
LIMIT ?
OFFSET ?
"#};
- let bindings = json!({
- "1": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
- "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "3": { "type": "TEXT", "value": after_date },
- "4": { "type": "FIXED", "value": min_minor_str_ref },
- "5": { "type": "FIXED", "value": min_minor_str_ref },
- "6": { "type": "FIXED", "value": min_minor_str_ref },
- "7": { "type": "FIXED", "value": min_patch_str_ref },
- "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "9": { "type": "FIXED", "value": offset.to_string() }
- });
-
let examples = fetch_examples_with_query(
http_client.clone(),
&step_progress,
background_executor.clone(),
statement,
- bindings,
- DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ QueryRetryState {
+ resume_after: after_date.clone(),
+ remaining_limit: max_rows_per_timestamp,
+ offset,
+ },
+ |retry_state| {
+ json!({
+ "1": { "type": "TEXT", "value": retry_state.resume_after },
+ "2": { "type": "FIXED", "value": min_minor_str_ref },
+ "3": { "type": "FIXED", "value": min_minor_str_ref },
+ "4": { "type": "FIXED", "value": min_minor_str_ref },
+ "5": { "type": "FIXED", "value": min_patch_str_ref },
+ "6": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+ "7": { "type": "FIXED", "value": retry_state.offset.to_string() }
+ })
+ },
&[
"request_id",
"device_id",
@@ -663,7 +816,7 @@ pub async fn fetch_captured_examples_after(
pub async fn fetch_settled_examples_after(
http_client: Arc<dyn HttpClient>,
after_timestamps: &[String],
- max_rows_per_timestamp: usize,
+ max_rows_per_timestamp: Option<usize>,
offset: usize,
background_executor: BackgroundExecutor,
min_capture_version: Option<MinCaptureVersion>,
@@ -684,55 +837,41 @@ pub async fn fetch_settled_examples_after(
let _ = min_capture_version;
let statement = indoc! {r#"
- WITH requested AS (
- SELECT
- req.event_properties:request_id::string AS request_id,
- req.device_id::string AS device_id,
- req.time AS req_time,
- req.time::string AS time,
- req.event_properties:input AS input,
- req.event_properties:format::string AS requested_format,
- req.event_properties:output::string AS requested_output,
- req.event_properties:zed_version::string AS zed_version
- FROM events req
- WHERE req.event_type = ?
- AND req.event_properties:version = 'V3'
- AND req.event_properties:input:can_collect_data = true
- AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
- )
SELECT
- req.request_id AS request_id,
- req.device_id AS device_id,
- req.time AS time,
- req.input AS input,
- req.requested_output AS requested_output,
- settled.event_properties:settled_editable_region::string AS settled_editable_region,
- req.requested_format AS requested_format,
- req.zed_version AS zed_version
- FROM requested req
- INNER JOIN events settled
- ON req.request_id = settled.event_properties:request_id::string
- WHERE settled.event_type = ?
- ORDER BY req.req_time ASC
+ ep_request_id AS request_id,
+ device_id AS device_id,
+ requested_at::string AS continuation_time,
+ requested_at::string AS time,
+ input_payload AS input,
+ requested_output AS requested_output,
+ settled_editable_region AS settled_editable_region,
+ requested_format AS requested_format,
+ zed_version AS zed_version
+ FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+ WHERE settled_editable_region IS NOT NULL
+ AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
+ ORDER BY requested_at ASC
LIMIT ?
OFFSET ?
"#};
- let bindings = json!({
- "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "2": { "type": "TEXT", "value": after_date },
- "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
- "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "5": { "type": "FIXED", "value": offset.to_string() }
- });
-
let examples = fetch_examples_with_query(
http_client.clone(),
&step_progress,
background_executor.clone(),
statement,
- bindings,
- SETTLED_STATEMENT_TIMEOUT_SECONDS,
+ QueryRetryState {
+ resume_after: after_date.clone(),
+ remaining_limit: max_rows_per_timestamp,
+ offset,
+ },
+ |retry_state| {
+ json!({
+ "1": { "type": "TEXT", "value": retry_state.resume_after },
+ "2": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+ "3": { "type": "FIXED", "value": retry_state.offset.to_string() }
+ })
+ },
&[
"request_id",
"device_id",
@@ -756,7 +895,7 @@ pub async fn fetch_settled_examples_after(
pub async fn fetch_rated_examples_after(
http_client: Arc<dyn HttpClient>,
inputs: &[(String, Option<EditPredictionRating>)],
- max_rows_per_timestamp: usize,
+ max_rows_per_timestamp: Option<usize>,
offset: usize,
background_executor: BackgroundExecutor,
_min_capture_version: Option<MinCaptureVersion>,
@@ -786,54 +925,48 @@ pub async fn fetch_rated_examples_after(
let statement = indoc! {r#"
SELECT
- rated.event_properties:request_id::string AS request_id,
- rated.event_properties:inputs AS inputs,
- rated.event_properties:output::string AS output,
- rated.event_properties:rating::string AS rating,
- rated.event_properties:feedback::string AS feedback,
- rated.device_id::string AS device_id,
- rated.time::string AS time,
- deploy.event_properties:experiment_name::string AS experiment_name,
- deploy.event_properties:environment::string AS environment,
- rated.event_properties:zed_version::string AS zed_version
- FROM events rated
- LEFT JOIN events req
- ON rated.event_properties:request_id::string = req.event_properties:request_id::string
- AND req.event_type = ?
- LEFT JOIN events deploy
- ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
- AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
- AND deploy.event_type = ?
- WHERE rated.event_type = ?
- AND (? IS NULL OR rated.event_properties:rating::string = ?)
- AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
- AND rated.event_properties:inputs IS NOT NULL
- AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
- AND rated.event_properties:output IS NOT NULL
- AND rated.event_properties:inputs:can_collect_data = true
- ORDER BY rated.time ASC
+ ep_request_id AS request_id,
+ rated_inputs AS inputs,
+ rated_output AS output,
+ rating AS rating,
+ feedback AS feedback,
+ device_id AS device_id,
+ requested_at::string AS continuation_time,
+ requested_at::string AS time,
+ NULL AS experiment_name,
+ NULL AS environment,
+ zed_version AS zed_version
+ FROM ZED_DBT.DBT_PROD.fct_edit_prediction_examples
+ WHERE rating IS NOT NULL
+ AND (? IS NULL OR rating = ?)
+ AND requested_at > TRY_TO_TIMESTAMP_NTZ(?)
+ AND rated_inputs IS NOT NULL
+ AND rated_inputs:cursor_excerpt IS NOT NULL
+ AND rated_output IS NOT NULL
+ ORDER BY requested_at ASC
LIMIT ?
OFFSET ?
"#};
- let bindings = json!({
- "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
- "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
- "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
- "4": { "type": "TEXT", "value": rating_value },
- "5": { "type": "TEXT", "value": rating_value },
- "6": { "type": "TEXT", "value": after_date },
- "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
- "8": { "type": "FIXED", "value": offset.to_string() }
- });
-
let examples = fetch_examples_with_query(
http_client.clone(),
&step_progress,
background_executor.clone(),
statement,
- bindings,
- DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+ QueryRetryState {
+ resume_after: after_date.clone(),
+ remaining_limit: max_rows_per_timestamp,
+ offset,
+ },
+ |retry_state| {
+ json!({
+ "1": { "type": "TEXT", "value": rating_value },
+ "2": { "type": "TEXT", "value": rating_value },
+ "3": { "type": "TEXT", "value": retry_state.resume_after },
+ "4": { "type": "FIXED", "value": format_limit(retry_state.remaining_limit) },
+ "5": { "type": "FIXED", "value": retry_state.offset.to_string() }
+ })
+ },
&[
"request_id",
"inputs",
@@ -1473,6 +1606,7 @@ fn rejected_examples_from_response<'a>(
let input_json = get_json("input");
let input: Option<ZetaPromptInput> =
input_json.clone().and_then(|v| serde_json::from_value(v).ok());
+ let prompt = get_string("prompt");
let output = get_string("output");
let was_shown = get_bool("was_shown");
let reason = get_string("reason");
@@ -1485,6 +1619,7 @@ fn rejected_examples_from_response<'a>(
device_id,
time,
input,
+ prompt,
output,
was_shown,
reason,
@@ -1515,6 +1650,7 @@ fn build_rejected_example(
device_id: String,
time: String,
input: ZetaPromptInput,
+ prompt: Option<String>,
output: String,
was_shown: bool,
reason: String,
@@ -1536,6 +1672,13 @@ fn build_rejected_example(
zed_version,
);
example.spec.rejected_patch = Some(rejected_patch);
+ example.prompt = prompt.map(|prompt| ExamplePrompt {
+ input: prompt,
+ expected_output: String::new(),
+ rejected_output: Some(output),
+ prefill: None,
+ provider: PredictionProvider::default(),
+ });
example
}
@@ -1635,11 +1778,42 @@ fn build_output_patch(
patch
}
+fn is_timeout_response(response: &SnowflakeStatementResponse) -> bool {
+ response.code.as_deref() == Some(SNOWFLAKE_TIMEOUT_CODE)
+ && response
+ .message
+ .as_deref()
+ .map(|message| message.to_ascii_lowercase().contains("timeout"))
+ .unwrap_or(false)
+}
+
/// Returns true when any error in the chain mentions the Snowflake timeout
/// status code (`000630`).
///
/// NOTE(review): this matches on formatted error text, which is fragile — it
/// relies on `run_sql` embedding the code in its bail message; revisit if the
/// error ever becomes a structured type.
fn is_snowflake_timeout_error(error: &anyhow::Error) -> bool {
    error
        .chain()
        .any(|cause| cause.to_string().contains(SNOWFLAKE_TIMEOUT_CODE))
}
+
+fn last_continuation_timestamp_from_response(
+ response: &SnowflakeStatementResponse,
+ column_indices: &HashMap<String, usize>,
+) -> Option<String> {
+ let continuation_time_index = column_indices.get("continuation_time").copied()?;
+ response
+ .data
+ .iter()
+ .rev()
+ .find_map(|row| match row.get(continuation_time_index)? {
+ JsonValue::String(value) => Some(value.clone()),
+ JsonValue::Null => None,
+ other => Some(other.to_string()),
+ })
+}
+
pub(crate) fn get_column_indices(
meta: &Option<SnowflakeResultSetMetaData>,
names: &[&str],
-) -> std::collections::HashMap<String, usize> {
- let mut indices = std::collections::HashMap::new();
+) -> HashMap<String, usize> {
+ let mut indices = HashMap::new();
if let Some(meta) = meta {
for (index, col) in meta.row_type.iter().enumerate() {
for &name in names {