pull_examples.rs

  1use anyhow::{Context as _, Result};
  2use flate2::read::GzDecoder;
  3use gpui::BackgroundExecutor;
  4use http_client::{AsyncBody, HttpClient, Method, Request};
  5use indoc::indoc;
  6use serde::Deserialize;
  7use serde_json::{Value as JsonValue, json};
  8use std::io::Read;
  9use std::sync::Arc;
 10use std::time::Duration;
 11
 12use crate::{
 13    example::Example,
 14    progress::{InfoStyle, Progress, Step},
 15};
 16use edit_prediction::example_spec::ExampleSpec;
 17
/// Response `code` treated as success; `examples_from_response` bails on any other code
/// that is present in a response.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` meaning the statement is still executing; `run_sql_with_polling` keeps
/// polling while the response carries this code.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// Value bound to the `event_type = ?` placeholder of the query.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";

/// Sent as the `timeout` field of the statement request.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each poll of a still-running statement.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls before `run_sql_with_polling` gives up with an error.
const MAX_POLL_ATTEMPTS: usize = 120;
 25
 26/// Parse an input token of the form `captured-after:{timestamp}`.
 27pub fn parse_captured_after_input(input: &str) -> Option<&str> {
 28    input.strip_prefix("captured-after:")
 29}
 30
 31pub async fn fetch_captured_examples_after(
 32    http_client: Arc<dyn HttpClient>,
 33    after_timestamps: &[String],
 34    max_rows_per_timestamp: usize,
 35    background_executor: BackgroundExecutor,
 36) -> Result<Vec<Example>> {
 37    if after_timestamps.is_empty() {
 38        return Ok(Vec::new());
 39    }
 40
 41    let progress = Progress::global();
 42
 43    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 44        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 45    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 46        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 47    )?;
 48    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 49
 50    let mut all_examples = Vec::new();
 51
 52    for after_date in after_timestamps.iter() {
 53        let step_progress_name = format!(">{after_date}");
 54        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 55        step_progress.set_substatus("querying");
 56
 57        let statement = indoc! {r#"
 58            SELECT
 59                event_properties:example AS example
 60            FROM events
 61            WHERE event_type = ?
 62                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 63            ORDER BY time ASC
 64            LIMIT ?
 65        "#};
 66
 67        let request = json!({
 68            "statement": statement,
 69            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 70            "database": "EVENTS",
 71            "schema": "PUBLIC",
 72            "warehouse": "DBT",
 73            "role": role,
 74            "bindings": {
 75                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 76                "2": { "type": "TEXT", "value": after_date },
 77                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 78            }
 79        });
 80
 81        let response = run_sql_with_polling(
 82            http_client.clone(),
 83            &base_url,
 84            &token,
 85            &request,
 86            &step_progress,
 87            background_executor.clone(),
 88        )
 89        .await?;
 90
 91        let total_rows = response
 92            .result_set_meta_data
 93            .as_ref()
 94            .and_then(|m| m.num_rows)
 95            .unwrap_or(response.data.len() as i64);
 96
 97        let num_partitions = response
 98            .result_set_meta_data
 99            .as_ref()
100            .map(|m| m.partition_info.len())
101            .unwrap_or(1)
102            .max(1);
103
104        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
105        step_progress.set_substatus("parsing");
106
107        all_examples.extend(examples_from_response(&response)?);
108
109        if num_partitions > 1 {
110            let statement_handle = response
111                .statement_handle
112                .as_ref()
113                .context("response has multiple partitions but no statementHandle")?;
114
115            for partition in 1..num_partitions {
116                step_progress.set_substatus(format!(
117                    "fetching partition {}/{}",
118                    partition + 1,
119                    num_partitions
120                ));
121
122                let partition_response = fetch_partition(
123                    http_client.clone(),
124                    &base_url,
125                    &token,
126                    statement_handle,
127                    partition,
128                )
129                .await?;
130
131                all_examples.extend(examples_from_response(&partition_response)?);
132            }
133        }
134
135        step_progress.set_substatus("done");
136    }
137
138    Ok(all_examples)
139}
140
/// JSON body returned by the statement endpoints, for both the initial submission and
/// partition fetches.
///
/// Every field falls back to its default when absent, so bodies that carry only a subset
/// of the fields still deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON values, one per column.
    #[serde(default)]
    data: Vec<Vec<JsonValue>>,
    /// Result metadata (column list, row count, partition list), when present.
    #[serde(default)]
    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code, compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    code: Option<String>,
    /// Message accompanying `code`; included in the error text for non-success codes.
    #[serde(default)]
    message: Option<String>,
    /// Handle used to poll the statement and to fetch additional partitions.
    #[serde(default)]
    statement_handle: Option<String>,
}
155
156#[derive(Debug, Clone, Deserialize)]
157#[serde(rename_all = "camelCase")]
158struct SnowflakeResultSetMetaData {
159    #[serde(default, rename = "rowType")]
160    row_type: Vec<SnowflakeColumnMeta>,
161    #[serde(default)]
162    num_rows: Option<i64>,
163    #[serde(default)]
164    partition_info: Vec<SnowflakePartitionInfo>,
165}
166
/// One entry of the `partition_info` list in the result metadata.
///
/// No fields are read from an entry: the caller only counts the entries to learn how many
/// partitions the result has.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
170
/// Descriptor of a single result column, as listed in the metadata's `row_type`.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; empty when absent. Compared case-insensitively with `"example"`.
    #[serde(default)]
    name: String,
}
176
177fn examples_from_response(
178    response: &SnowflakeStatementResponse,
179) -> Result<impl Iterator<Item = Example> + '_> {
180    if let Some(code) = &response.code {
181        if code != SNOWFLAKE_SUCCESS_CODE {
182            anyhow::bail!(
183                "snowflake sql api returned error code={code} message={}",
184                response.message.as_deref().unwrap_or("<no message>")
185            );
186        }
187    }
188
189    let example_index = response
190        .result_set_meta_data
191        .as_ref()
192        .and_then(|m| {
193            m.row_type.iter().enumerate().find_map(|(index, col)| {
194                if col.name.eq_ignore_ascii_case("example") {
195                    Some(index)
196                } else {
197                    None
198                }
199            })
200        })
201        .unwrap_or(0);
202
203    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
204        let Some(example_value) = data_row.get(example_index) else {
205            return None;
206        };
207        if example_value.is_null() {
208            return None;
209        }
210
211        let parse_result = match example_value {
212            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
213            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
214        };
215
216        match parse_result {
217            Ok(spec) => Some(Example {
218                spec,
219                prompt_inputs: None,
220                prompt: None,
221                predictions: Vec::new(),
222                score: Vec::new(),
223                state: None,
224            }),
225            Err(error) => {
226                let raw_json = serde_json::to_string_pretty(example_value)
227                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
228                log::error!(
229                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
230                );
231                None
232            }
233        }
234    });
235
236    Ok(iter)
237}
238
239async fn run_sql_with_polling(
240    http_client: Arc<dyn HttpClient>,
241    base_url: &str,
242    token: &str,
243    request: &serde_json::Value,
244    step_progress: &crate::progress::StepProgress,
245    background_executor: BackgroundExecutor,
246) -> Result<SnowflakeStatementResponse> {
247    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
248
249    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
250        let statement_handle = response
251            .statement_handle
252            .as_ref()
253            .context("async query response missing statementHandle")?
254            .clone();
255
256        for attempt in 1..=MAX_POLL_ATTEMPTS {
257            step_progress.set_substatus(format!("polling ({attempt})"));
258
259            background_executor.timer(POLL_INTERVAL).await;
260
261            response =
262                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
263
264            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
265                break;
266            }
267        }
268
269        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
270            anyhow::bail!(
271                "query still running after {} poll attempts ({} seconds)",
272                MAX_POLL_ATTEMPTS,
273                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
274            );
275        }
276    }
277
278    Ok(response)
279}
280
281async fn fetch_partition(
282    http_client: Arc<dyn HttpClient>,
283    base_url: &str,
284    token: &str,
285    statement_handle: &str,
286    partition: usize,
287) -> Result<SnowflakeStatementResponse> {
288    let url = format!(
289        "{}/api/v2/statements/{}?partition={}",
290        base_url.trim_end_matches('/'),
291        statement_handle,
292        partition
293    );
294
295    let http_request = Request::builder()
296        .method(Method::GET)
297        .uri(url.as_str())
298        .header("Authorization", format!("Bearer {token}"))
299        .header(
300            "X-Snowflake-Authorization-Token-Type",
301            "PROGRAMMATIC_ACCESS_TOKEN",
302        )
303        .header("Accept", "application/json")
304        .header("Accept-Encoding", "gzip")
305        .body(AsyncBody::empty())?;
306
307    let response = http_client
308        .send(http_request)
309        .await
310        .context("failed to send partition request to Snowflake SQL API")?;
311
312    let status = response.status();
313    let content_encoding = response
314        .headers()
315        .get("content-encoding")
316        .and_then(|v| v.to_str().ok())
317        .map(|s| s.to_lowercase());
318
319    let body_bytes = {
320        use futures::AsyncReadExt as _;
321
322        let mut body = response.into_body();
323        let mut bytes = Vec::new();
324        body.read_to_end(&mut bytes)
325            .await
326            .context("failed to read Snowflake SQL API partition response body")?;
327        bytes
328    };
329
330    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
331        let mut decoder = GzDecoder::new(&body_bytes[..]);
332        let mut decompressed = Vec::new();
333        decoder
334            .read_to_end(&mut decompressed)
335            .context("failed to decompress gzip response")?;
336        decompressed
337    } else {
338        body_bytes
339    };
340
341    if !status.is_success() && status.as_u16() != 202 {
342        let body_text = String::from_utf8_lossy(&body_bytes);
343        anyhow::bail!(
344            "snowflake sql api partition request http {}: {}",
345            status.as_u16(),
346            body_text
347        );
348    }
349
350    if body_bytes.is_empty() {
351        anyhow::bail!(
352            "snowflake sql api partition {} returned empty response body (http {})",
353            partition,
354            status.as_u16()
355        );
356    }
357
358    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
359        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
360        format!(
361            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
362            partition,
363            status.as_u16(),
364            body_preview
365        )
366    })
367}
368
369async fn run_sql(
370    http_client: Arc<dyn HttpClient>,
371    base_url: &str,
372    token: &str,
373    request: &serde_json::Value,
374) -> Result<SnowflakeStatementResponse> {
375    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
376
377    let request_body =
378        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
379
380    let http_request = Request::builder()
381        .method(Method::POST)
382        .uri(url.as_str())
383        .header("Authorization", format!("Bearer {token}"))
384        .header(
385            "X-Snowflake-Authorization-Token-Type",
386            "PROGRAMMATIC_ACCESS_TOKEN",
387        )
388        .header("Content-Type", "application/json")
389        .header("Accept", "application/json")
390        .body(AsyncBody::from(request_body.clone()))?;
391
392    let response = http_client
393        .send(http_request)
394        .await
395        .context("failed to send request to Snowflake SQL API")?;
396
397    let status = response.status();
398    let body_bytes = {
399        use futures::AsyncReadExt as _;
400
401        let mut body = response.into_body();
402        let mut bytes = Vec::new();
403        body.read_to_end(&mut bytes)
404            .await
405            .context("failed to read Snowflake SQL API response body")?;
406        bytes
407    };
408
409    if !status.is_success() && status.as_u16() != 202 {
410        let body_text = String::from_utf8_lossy(&body_bytes);
411        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
412    }
413
414    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
415        .context("failed to parse Snowflake SQL API response JSON")
416}