1use anyhow::{Context as _, Result};
2use http_client::{AsyncBody, HttpClient, Method, Request};
3use indoc::indoc;
4use serde::Deserialize;
5use serde_json::{Value as JsonValue, json};
6use std::sync::Arc;
7
8use crate::{
9 example::Example,
10 progress::{InfoStyle, Progress, Step},
11};
12use edit_prediction::example_spec::ExampleSpec;
13
/// Snowflake SQL API status `code` that a response must carry to be treated as a success.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Value bound to `event_type` when selecting captured edit-prediction examples.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";

/// Sent as the `timeout` field of every SQL API statement request (seconds).
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 2 * 60;
18
/// Extracts the timestamp from an input token written as `captured-after:{timestamp}`.
///
/// Returns everything after the `captured-after:` prefix (possibly an empty string), or
/// `None` when `input` does not begin with that exact prefix. The timestamp text itself
/// is not validated here.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const PREFIX: &str = "captured-after:";
    input.strip_prefix(PREFIX)
}
23
24pub async fn fetch_captured_examples_after(
25 http_client: Arc<dyn HttpClient>,
26 after_timestamps: &[String],
27 max_rows_per_timestamp: usize,
28) -> Result<Vec<Example>> {
29 if after_timestamps.is_empty() {
30 return Ok(Vec::new());
31 }
32
33 let progress = Progress::global();
34
35 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
36 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
37 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
38 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
39 )?;
40 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
41
42 let mut all_examples = Vec::new();
43
44 for after_date in after_timestamps.iter() {
45 let step_progress_name = format!(">{after_date}");
46 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
47 step_progress.set_substatus("querying");
48
49 let statement = indoc! {r#"
50 SELECT
51 event_properties:example AS example
52 FROM events
53 WHERE event_type = ?
54 AND time > TRY_TO_TIMESTAMP_NTZ(?)
55 ORDER BY time ASC
56 LIMIT ?
57 "#};
58
59 let request = json!({
60 "statement": statement,
61 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
62 "database": "EVENTS",
63 "schema": "PUBLIC",
64 "warehouse": "DBT",
65 "role": role,
66 "bindings": {
67 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
68 "2": { "type": "TEXT", "value": after_date },
69 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
70 }
71 });
72
73 let response = run_sql(http_client.clone(), &base_url, &token, &request).await?;
74
75 step_progress.set_info(format!("{} rows", response.data.len()), InfoStyle::Normal);
76 step_progress.set_substatus("parsing");
77
78 all_examples.extend(examples_from_response(&response)?);
79
80 step_progress.set_substatus("done");
81 }
82
83 Ok(all_examples)
84}
85
86#[derive(Debug, Clone, Deserialize)]
87struct SnowflakeStatementResponse {
88 #[serde(default)]
89 data: Vec<Vec<JsonValue>>,
90 #[serde(default)]
91 result_set_meta_data: Option<SnowflakeResultSetMetaData>,
92 #[serde(default)]
93 code: Option<String>,
94 #[serde(default)]
95 message: Option<String>,
96}
97
98#[derive(Debug, Clone, Deserialize)]
99struct SnowflakeResultSetMetaData {
100 #[serde(default, rename = "rowType")]
101 row_type: Vec<SnowflakeColumnMeta>,
102}
103
104#[derive(Debug, Clone, Deserialize)]
105struct SnowflakeColumnMeta {
106 #[serde(default)]
107 name: String,
108}
109
110fn examples_from_response(
111 response: &SnowflakeStatementResponse,
112) -> Result<impl Iterator<Item = Example>> {
113 if let Some(code) = &response.code {
114 if code != SNOWFLAKE_SUCCESS_CODE {
115 anyhow::bail!(
116 "snowflake sql api returned error code={code} message={}",
117 response.message.as_deref().unwrap_or("<no message>")
118 );
119 }
120 }
121
122 let example_index = response
123 .result_set_meta_data
124 .as_ref()
125 .and_then(|m| {
126 m.row_type.iter().enumerate().find_map(|(index, col)| {
127 if col.name.eq_ignore_ascii_case("example") {
128 Some(index)
129 } else {
130 None
131 }
132 })
133 })
134 .unwrap_or(0);
135
136 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
137 let Some(example_value) = data_row.get(example_index) else {
138 return None;
139 };
140 if example_value.is_null() {
141 return None;
142 }
143
144 let parse_result = match example_value {
145 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
146 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
147 };
148
149 match parse_result {
150 Ok(spec) => Some(Example {
151 spec,
152 buffer: None,
153 context: None,
154 prompt: None,
155 predictions: Vec::new(),
156 score: Vec::new(),
157 state: None,
158 }),
159 Err(error) => {
160 let raw_json = serde_json::to_string_pretty(example_value)
161 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
162 log::error!(
163 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
164 );
165 None
166 }
167 }
168 });
169
170 Ok(iter)
171}
172
173async fn run_sql(
174 http_client: Arc<dyn HttpClient>,
175 base_url: &str,
176 token: &str,
177 request: &serde_json::Value,
178) -> Result<SnowflakeStatementResponse> {
179 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
180
181 let request_body =
182 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
183
184 let http_request = Request::builder()
185 .method(Method::POST)
186 .uri(url.as_str())
187 .header("Authorization", format!("Bearer {token}"))
188 .header(
189 "X-Snowflake-Authorization-Token-Type",
190 "PROGRAMMATIC_ACCESS_TOKEN",
191 )
192 .header("Content-Type", "application/json")
193 .header("Accept", "application/json")
194 .body(AsyncBody::from(request_body.clone()))?;
195
196 let response = http_client
197 .send(http_request)
198 .await
199 .context("failed to send request to Snowflake SQL API")?;
200
201 let status = response.status();
202 let body_bytes = {
203 use futures::AsyncReadExt as _;
204
205 let mut body = response.into_body();
206 let mut bytes = Vec::new();
207 body.read_to_end(&mut bytes)
208 .await
209 .context("failed to read Snowflake SQL API response body")?;
210 bytes
211 };
212
213 if !status.is_success() {
214 let body_text = String::from_utf8_lossy(&body_bytes);
215 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
216 }
217
218 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
219 .context("failed to parse Snowflake SQL API response JSON")
220}