1use anyhow::{Context as _, Result};
2use flate2::read::GzDecoder;
3use gpui::BackgroundExecutor;
4use http_client::{AsyncBody, HttpClient, Method, Request};
5use indoc::indoc;
6use serde::Deserialize;
7use serde_json::{Value as JsonValue, json};
8use std::io::Read;
9use std::sync::Arc;
10use std::time::Duration;
11
12use crate::{
13 example::Example,
14 progress::{InfoStyle, Progress, Step},
15};
16use edit_prediction::example_spec::ExampleSpec;
17
/// Response `code` that `examples_from_response` accepts as success; any other
/// code present in a response makes it return an error.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` that `run_sql_with_polling` treats as "statement still
/// running", which makes it keep polling.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// Value bound to the `event_type = ?` filter of the example query.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";

/// Sent as the `timeout` field of every statement request.
/// NOTE(review): presumably interpreted by Snowflake as seconds, as the name
/// suggests — confirm against the SQL API documentation.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each poll of a statement that is still running.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Number of polls after which `run_sql_with_polling` gives up with an error.
const MAX_POLL_ATTEMPTS: usize = 120;
25
/// Extracts the timestamp part of an input token shaped like
/// `captured-after:{timestamp}`.
///
/// Returns `None` when `input` does not start with the `captured-after:`
/// prefix. Otherwise returns everything after the prefix, without validating
/// that it is a well-formed timestamp.
pub fn parse_captured_after_input(input: &str) -> Option<&str> {
    const CAPTURED_AFTER_PREFIX: &str = "captured-after:";

    input.strip_prefix(CAPTURED_AFTER_PREFIX)
}
30
31pub async fn fetch_captured_examples_after(
32 http_client: Arc<dyn HttpClient>,
33 after_timestamps: &[String],
34 max_rows_per_timestamp: usize,
35 background_executor: BackgroundExecutor,
36) -> Result<Vec<Example>> {
37 if after_timestamps.is_empty() {
38 return Ok(Vec::new());
39 }
40
41 let progress = Progress::global();
42
43 let token = std::env::var("EP_SNOWFLAKE_API_KEY")
44 .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
45 let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
46 "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
47 )?;
48 let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
49
50 let mut all_examples = Vec::new();
51
52 for after_date in after_timestamps.iter() {
53 let step_progress_name = format!(">{after_date}");
54 let step_progress = progress.start(Step::PullExamples, &step_progress_name);
55 step_progress.set_substatus("querying");
56
57 let statement = indoc! {r#"
58 SELECT
59 event_properties:example AS example
60 FROM events
61 WHERE event_type = ?
62 AND time > TRY_TO_TIMESTAMP_NTZ(?)
63 ORDER BY time ASC
64 LIMIT ?
65 "#};
66
67 let request = json!({
68 "statement": statement,
69 "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
70 "database": "EVENTS",
71 "schema": "PUBLIC",
72 "warehouse": "DBT",
73 "role": role,
74 "bindings": {
75 "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
76 "2": { "type": "TEXT", "value": after_date },
77 "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
78 }
79 });
80
81 let response = run_sql_with_polling(
82 http_client.clone(),
83 &base_url,
84 &token,
85 &request,
86 &step_progress,
87 background_executor.clone(),
88 )
89 .await?;
90
91 let total_rows = response
92 .result_set_meta_data
93 .as_ref()
94 .and_then(|m| m.num_rows)
95 .unwrap_or(response.data.len() as i64);
96
97 let num_partitions = response
98 .result_set_meta_data
99 .as_ref()
100 .map(|m| m.partition_info.len())
101 .unwrap_or(1)
102 .max(1);
103
104 step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
105 step_progress.set_substatus("parsing");
106
107 all_examples.extend(examples_from_response(&response)?);
108
109 if num_partitions > 1 {
110 let statement_handle = response
111 .statement_handle
112 .as_ref()
113 .context("response has multiple partitions but no statementHandle")?;
114
115 for partition in 1..num_partitions {
116 step_progress.set_substatus(format!(
117 "fetching partition {}/{}",
118 partition + 1,
119 num_partitions
120 ));
121
122 let partition_response = fetch_partition(
123 http_client.clone(),
124 &base_url,
125 &token,
126 statement_handle,
127 partition,
128 )
129 .await?;
130
131 all_examples.extend(examples_from_response(&partition_response)?);
132 }
133 }
134
135 step_progress.set_substatus("done");
136 }
137
138 Ok(all_examples)
139}
140
/// JSON body of a Snowflake SQL API statement response, as deserialized by
/// `run_sql` and `fetch_partition`.
///
/// Keys are matched in camelCase, and every field falls back to its default
/// when it is absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON column values.
    #[serde(default)]
    data: Vec<Vec<JsonValue>>,
    /// Result metadata (columns, row count, partitions), when the response has it.
    #[serde(default)]
    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code string, compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    code: Option<String>,
    /// Message accompanying `code`; included in the error for non-success codes.
    #[serde(default)]
    message: Option<String>,
    /// Handle used to build the `/api/v2/statements/{handle}` URL when polling
    /// or fetching further partitions.
    #[serde(default)]
    statement_handle: Option<String>,
}
155
/// Metadata section of a statement response.
///
/// Keys are matched in camelCase; fields missing from the JSON take their
/// defaults.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeResultSetMetaData {
    /// Column descriptors; `examples_from_response` searches them for the
    /// `example` column.
    // The explicit rename produces the same key (`rowType`) that
    // `rename_all = "camelCase"` would generate for this field.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count taken from the response, when present; shown as progress info.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
166
/// Placeholder for one entry of `partitionInfo`.
///
/// No fields are read: the code only counts how many entries the response
/// lists, so the entry contents are left undeclared.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
170
/// Descriptor of one result column, taken from the metadata `rowType` list.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; defaults to the empty string when missing. Compared
    /// case-insensitively (ASCII) against `example` by `examples_from_response`.
    #[serde(default)]
    name: String,
}
176
177fn examples_from_response(
178 response: &SnowflakeStatementResponse,
179) -> Result<impl Iterator<Item = Example> + '_> {
180 if let Some(code) = &response.code {
181 if code != SNOWFLAKE_SUCCESS_CODE {
182 anyhow::bail!(
183 "snowflake sql api returned error code={code} message={}",
184 response.message.as_deref().unwrap_or("<no message>")
185 );
186 }
187 }
188
189 let example_index = response
190 .result_set_meta_data
191 .as_ref()
192 .and_then(|m| {
193 m.row_type.iter().enumerate().find_map(|(index, col)| {
194 if col.name.eq_ignore_ascii_case("example") {
195 Some(index)
196 } else {
197 None
198 }
199 })
200 })
201 .unwrap_or(0);
202
203 let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
204 let Some(example_value) = data_row.get(example_index) else {
205 return None;
206 };
207 if example_value.is_null() {
208 return None;
209 }
210
211 let parse_result = match example_value {
212 JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
213 _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
214 };
215
216 match parse_result {
217 Ok(spec) => Some(Example {
218 spec,
219 prompt_inputs: None,
220 prompt: None,
221 predictions: Vec::new(),
222 score: Vec::new(),
223 state: None,
224 }),
225 Err(error) => {
226 let raw_json = serde_json::to_string_pretty(example_value)
227 .unwrap_or_else(|_| "<failed to serialize json>".to_string());
228 log::error!(
229 "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
230 );
231 None
232 }
233 }
234 });
235
236 Ok(iter)
237}
238
239async fn run_sql_with_polling(
240 http_client: Arc<dyn HttpClient>,
241 base_url: &str,
242 token: &str,
243 request: &serde_json::Value,
244 step_progress: &crate::progress::StepProgress,
245 background_executor: BackgroundExecutor,
246) -> Result<SnowflakeStatementResponse> {
247 let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
248
249 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
250 let statement_handle = response
251 .statement_handle
252 .as_ref()
253 .context("async query response missing statementHandle")?
254 .clone();
255
256 for attempt in 1..=MAX_POLL_ATTEMPTS {
257 step_progress.set_substatus(format!("polling ({attempt})"));
258
259 background_executor.timer(POLL_INTERVAL).await;
260
261 response =
262 fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
263
264 if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
265 break;
266 }
267 }
268
269 if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
270 anyhow::bail!(
271 "query still running after {} poll attempts ({} seconds)",
272 MAX_POLL_ATTEMPTS,
273 MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
274 );
275 }
276 }
277
278 Ok(response)
279}
280
281async fn fetch_partition(
282 http_client: Arc<dyn HttpClient>,
283 base_url: &str,
284 token: &str,
285 statement_handle: &str,
286 partition: usize,
287) -> Result<SnowflakeStatementResponse> {
288 let url = format!(
289 "{}/api/v2/statements/{}?partition={}",
290 base_url.trim_end_matches('/'),
291 statement_handle,
292 partition
293 );
294
295 let http_request = Request::builder()
296 .method(Method::GET)
297 .uri(url.as_str())
298 .header("Authorization", format!("Bearer {token}"))
299 .header(
300 "X-Snowflake-Authorization-Token-Type",
301 "PROGRAMMATIC_ACCESS_TOKEN",
302 )
303 .header("Accept", "application/json")
304 .header("Accept-Encoding", "gzip")
305 .body(AsyncBody::empty())?;
306
307 let response = http_client
308 .send(http_request)
309 .await
310 .context("failed to send partition request to Snowflake SQL API")?;
311
312 let status = response.status();
313 let content_encoding = response
314 .headers()
315 .get("content-encoding")
316 .and_then(|v| v.to_str().ok())
317 .map(|s| s.to_lowercase());
318
319 let body_bytes = {
320 use futures::AsyncReadExt as _;
321
322 let mut body = response.into_body();
323 let mut bytes = Vec::new();
324 body.read_to_end(&mut bytes)
325 .await
326 .context("failed to read Snowflake SQL API partition response body")?;
327 bytes
328 };
329
330 let body_bytes = if content_encoding.as_deref() == Some("gzip") {
331 let mut decoder = GzDecoder::new(&body_bytes[..]);
332 let mut decompressed = Vec::new();
333 decoder
334 .read_to_end(&mut decompressed)
335 .context("failed to decompress gzip response")?;
336 decompressed
337 } else {
338 body_bytes
339 };
340
341 if !status.is_success() && status.as_u16() != 202 {
342 let body_text = String::from_utf8_lossy(&body_bytes);
343 anyhow::bail!(
344 "snowflake sql api partition request http {}: {}",
345 status.as_u16(),
346 body_text
347 );
348 }
349
350 if body_bytes.is_empty() {
351 anyhow::bail!(
352 "snowflake sql api partition {} returned empty response body (http {})",
353 partition,
354 status.as_u16()
355 );
356 }
357
358 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
359 let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
360 format!(
361 "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
362 partition,
363 status.as_u16(),
364 body_preview
365 )
366 })
367}
368
369async fn run_sql(
370 http_client: Arc<dyn HttpClient>,
371 base_url: &str,
372 token: &str,
373 request: &serde_json::Value,
374) -> Result<SnowflakeStatementResponse> {
375 let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
376
377 let request_body =
378 serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
379
380 let http_request = Request::builder()
381 .method(Method::POST)
382 .uri(url.as_str())
383 .header("Authorization", format!("Bearer {token}"))
384 .header(
385 "X-Snowflake-Authorization-Token-Type",
386 "PROGRAMMATIC_ACCESS_TOKEN",
387 )
388 .header("Content-Type", "application/json")
389 .header("Accept", "application/json")
390 .body(AsyncBody::from(request_body.clone()))?;
391
392 let response = http_client
393 .send(http_request)
394 .await
395 .context("failed to send request to Snowflake SQL API")?;
396
397 let status = response.status();
398 let body_bytes = {
399 use futures::AsyncReadExt as _;
400
401 let mut body = response.into_body();
402 let mut bytes = Vec::new();
403 body.read_to_end(&mut bytes)
404 .await
405 .context("failed to read Snowflake SQL API response body")?;
406 bytes
407 };
408
409 if !status.is_success() && status.as_u16() != 202 {
410 let body_text = String::from_utf8_lossy(&body_bytes);
411 anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
412 }
413
414 serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
415 .context("failed to parse Snowflake SQL API response JSON")
416}