pull_examples.rs

  1use anyhow::{Context as _, Result};
  2use http_client::{AsyncBody, HttpClient, Method, Request};
  3use indoc::indoc;
  4use serde::Deserialize;
  5use serde_json::{Value as JsonValue, json};
  6use std::sync::Arc;
  7
  8use crate::{
  9    example::Example,
 10    progress::{InfoStyle, Progress, Step},
 11};
 12use edit_prediction::example_spec::ExampleSpec;
 13
/// Value of the response `code` field that `examples_from_response` accepts;
/// any other code present in a response is reported as an error.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Event name bound to the `event_type = ?` placeholder of the query issued by
/// `fetch_captured_examples_after`.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";

/// Sent as the `timeout` field of every statement request.
/// NOTE(review): the unit is seconds according to the name — confirm against the SQL API docs.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
 18
 19/// Parse an input token of the form `captured-after:{timestamp}`.
 20pub fn parse_captured_after_input(input: &str) -> Option<&str> {
 21    input.strip_prefix("captured-after:")
 22}
 23
 24pub async fn fetch_captured_examples_after(
 25    http_client: Arc<dyn HttpClient>,
 26    after_timestamps: &[String],
 27    max_rows_per_timestamp: usize,
 28) -> Result<Vec<Example>> {
 29    if after_timestamps.is_empty() {
 30        return Ok(Vec::new());
 31    }
 32
 33    let progress = Progress::global();
 34
 35    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 36        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 37    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 38        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 39    )?;
 40    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 41
 42    let mut all_examples = Vec::new();
 43
 44    for after_date in after_timestamps.iter() {
 45        let step_progress_name = format!(">{after_date}");
 46        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 47        step_progress.set_substatus("querying");
 48
 49        let statement = indoc! {r#"
 50            SELECT
 51                event_properties:example AS example
 52            FROM events
 53            WHERE event_type = ?
 54                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 55            ORDER BY time ASC
 56            LIMIT ?
 57        "#};
 58
 59        let request = json!({
 60            "statement": statement,
 61            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 62            "database": "EVENTS",
 63            "schema": "PUBLIC",
 64            "warehouse": "DBT",
 65            "role": role,
 66            "bindings": {
 67                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 68                "2": { "type": "TEXT", "value": after_date },
 69                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 70            }
 71        });
 72
 73        let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
 74
 75        step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
 76        step_progress.set_substatus("parsing");
 77
 78        all_examples.extend(examples_from_response(&response)?);
 79
 80        step_progress.set_substatus("done");
 81    }
 82
 83    Ok(all_examples)
 84}
 85
 86#[derive(Debug, Clone, Deserialize)]
 87struct SnowflakeStatementResponse {
 88    #[serde(default)]
 89    data: Vec<Vec<JsonValue>>,
 90    #[serde(default)]
 91    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
 92    #[serde(default)]
 93    code: Option<String>,
 94    #[serde(default)]
 95    message: Option<String>,
 96}
 97
/// Result-set metadata of a statement response; only the column list is read.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeResultSetMetaData {
    /// Column descriptions, read from the JSON key `rowType`; empty when the key is absent.
    /// `examples_from_response` uses an entry's position as the column index into each data row.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
}
103
/// Description of a single result column; only its name is read.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; defaults to the empty string when missing.
    /// `examples_from_response` compares it ASCII-case-insensitively against `example`.
    #[serde(default)]
    name: String,
}
109
110fn examples_from_response(
111    response: &SnowflakeStatementResponse,
112) -> Result<impl Iterator<Item = Example>> {
113    if let Some(code) = &response.code {
114        if code != SNOWFLAKE_SUCCESS_CODE {
115            anyhow::bail!(
116                "snowflake sql api returned error code={code} message={}",
117                response.message.as_deref().unwrap_or("<no message>")
118            );
119        }
120    }
121
122    let example_index = response
123        .result_set_meta_data
124        .as_ref()
125        .and_then(|m| {
126            m.row_type.iter().enumerate().find_map(|(index, col)| {
127                if col.name.eq_ignore_ascii_case("example") {
128                    Some(index)
129                } else {
130                    None
131                }
132            })
133        })
134        .unwrap_or(0);
135
136    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
137        let Some(example_value) = data_row.get(example_index) else {
138            return None;
139        };
140        if example_value.is_null() {
141            return None;
142        }
143
144        let parse_result = match example_value {
145            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
146            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
147        };
148
149        match parse_result {
150            Ok(spec) => Some(Example {
151                spec,
152                buffer: None,
153                context: None,
154                prompt: None,
155                predictions: Vec::new(),
156                score: Vec::new(),
157                state: None,
158            }),
159            Err(error) => {
160                let raw_json = serde_json::to_string_pretty(example_value)
161                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
162                log::error!(
163                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
164                );
165                None
166            }
167        }
168    });
169
170    Ok(iter)
171}
172
173async fn run_sql(
174    http_client: Arc<dyn HttpClient>,
175    base_url: &str,
176    token: &str,
177    request: &serde_json::Value,
178) -> Result<SnowflakeStatementResponse> {
179    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
180
181    let request_body =
182        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
183
184    let http_request = Request::builder()
185        .method(Method::POST)
186        .uri(url.as_str())
187        .header("Authorization", format!("Bearer {token}"))
188        .header(
189            "X-Snowflake-Authorization-Token-Type",
190            "PROGRAMMATIC_ACCESS_TOKEN",
191        )
192        .header("Content-Type", "application/json")
193        .header("Accept", "application/json")
194        .body(AsyncBody::from(request_body.clone()))?;
195
196    let response = http_client
197        .send(http_request)
198        .await
199        .context("failed to send request to Snowflake SQL API")?;
200
201    let status = response.status();
202    let body_bytes = {
203        use futures::AsyncReadExt as _;
204
205        let mut body = response.into_body();
206        let mut bytes = Vec::new();
207        body.read_to_end(&mut bytes)
208            .await
209            .context("failed to read Snowflake SQL API response body")?;
210        bytes
211    };
212
213    if !status.is_success() {
214        let body_text = String::from_utf8_lossy(&body_bytes);
215        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
216    }
217
218    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
219        .context("failed to parse Snowflake SQL API response JSON")
220}