pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::io::Read;
   9use std::sync::Arc;
  10use std::time::Duration;
  11use telemetry_events::EditPredictionRating;
  12
  13use zeta_prompt::ZetaPromptInput;
  14
  15use crate::example::Example;
  16use crate::progress::{InfoStyle, Progress, Step};
  17use edit_prediction::example_spec::{
  18    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
  19    TelemetrySource,
  20};
  21use std::fmt::Write as _;
  22
/// Statement status `code` accepted as success when turning a response into examples; any
/// other `code` present on a response is reported as an error.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Statement status `code` meaning the statement is still executing; `run_sql_with_polling`
/// keeps polling while responses carry this code.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` of events whose `example` property holds a serialized `ExampleSpec`.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` of prediction-request events (queried directly and joined with rejections).
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` of prediction-rejection events, joined to requests on `request_id`.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// `event_type` of user rating events.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";

/// Value sent as the `timeout` field of every statement request (seconds, per the name).
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay before each status poll of a statement that is still running.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Number of status polls performed before giving up on a running statement.
const MAX_POLL_ATTEMPTS: usize = 120;
  33
  34/// Parse an input token of the form `captured-after:{timestamp}`.
  35pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  36    input.strip_prefix("captured-after:")
  37}
  38
  39/// Parse an input token of the form `rejected-after:{timestamp}`.
  40pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  41    input.strip_prefix("rejected-after:")
  42}
  43
  44/// Parse an input token of the form `requested-after:{timestamp}`.
  45pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  46    input.strip_prefix("requested-after:")
  47}
  48
  49/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  50/// or `rated-negative-after:{timestamp}`.
  51/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  52pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  53    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  54        Some((timestamp, Some(EditPredictionRating::Positive)))
  55    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  56        Some((timestamp, Some(EditPredictionRating::Negative)))
  57    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  58        Some((timestamp, None))
  59    } else {
  60        None
  61    }
  62}
  63
  64pub async fn fetch_captured_examples_after(
  65    http_client: Arc<dyn HttpClient>,
  66    after_timestamps: &[String],
  67    max_rows_per_timestamp: usize,
  68    background_executor: BackgroundExecutor,
  69) -> Result<Vec<Example>> {
  70    if after_timestamps.is_empty() {
  71        return Ok(Vec::new());
  72    }
  73
  74    let progress = Progress::global();
  75
  76    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
  77        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
  78    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
  79        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
  80    )?;
  81    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
  82
  83    let mut all_examples = Vec::new();
  84
  85    for after_date in after_timestamps.iter() {
  86        let step_progress_name = format!(">{after_date}");
  87        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
  88        step_progress.set_substatus("querying");
  89
  90        let statement = indoc! {r#"
  91            SELECT
  92                event_properties:example AS example
  93            FROM events
  94            WHERE event_type = ?
  95                AND time > TRY_TO_TIMESTAMP_NTZ(?)
  96            ORDER BY time ASC
  97            LIMIT ?
  98        "#};
  99
 100        let request = json!({
 101            "statement": statement,
 102            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 103            "database": "EVENTS",
 104            "schema": "PUBLIC",
 105            "warehouse": "DBT",
 106            "role": role,
 107            "bindings": {
 108                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 109                "2": { "type": "TEXT", "value": after_date },
 110                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 111            }
 112        });
 113
 114        let response = run_sql_with_polling(
 115            http_client.clone(),
 116            &base_url,
 117            &token,
 118            &request,
 119            &step_progress,
 120            background_executor.clone(),
 121        )
 122        .await?;
 123
 124        let total_rows = response
 125            .result_set_meta_data
 126            .as_ref()
 127            .and_then(|m| m.num_rows)
 128            .unwrap_or(response.data.len() as i64);
 129
 130        let num_partitions = response
 131            .result_set_meta_data
 132            .as_ref()
 133            .map(|m| m.partition_info.len())
 134            .unwrap_or(1)
 135            .max(1);
 136
 137        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 138        step_progress.set_substatus("parsing");
 139
 140        let example_index = response
 141            .result_set_meta_data
 142            .as_ref()
 143            .and_then(|m| {
 144                m.row_type.iter().enumerate().find_map(|(index, col)| {
 145                    if col.name.eq_ignore_ascii_case("example") {
 146                        Some(index)
 147                    } else {
 148                        None
 149                    }
 150                })
 151            })
 152            .unwrap_or(0);
 153
 154        all_examples.extend(examples_from_response(&response, example_index)?);
 155
 156        if num_partitions > 1 {
 157            let statement_handle = response
 158                .statement_handle
 159                .as_ref()
 160                .context("response has multiple partitions but no statementHandle")?;
 161
 162            for partition in 1..num_partitions {
 163                step_progress.set_substatus(format!(
 164                    "fetching partition {}/{}",
 165                    partition + 1,
 166                    num_partitions
 167                ));
 168
 169                let partition_response = fetch_partition(
 170                    http_client.clone(),
 171                    &base_url,
 172                    &token,
 173                    statement_handle,
 174                    partition,
 175                )
 176                .await?;
 177
 178                all_examples.extend(examples_from_response(&partition_response, example_index)?);
 179            }
 180        }
 181
 182        step_progress.set_substatus("done");
 183    }
 184
 185    Ok(all_examples)
 186}
 187
/// Body returned by the Snowflake SQL API statement endpoints (submit, status poll, and
/// partition fetch). Every field defaults when absent, so partial bodies still deserialize.
/// JSON keys are the camelCase forms of the field names (e.g. `statementHandle`).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
    /// Result rows; each row holds one JSON value per selected column.
    #[serde(default)]
    data: Vec<Vec<JsonValue>>,
    /// Column, row-count and partition metadata, when present in the body.
    #[serde(default)]
    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Statement status code, compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    code: Option<String>,
    /// Human-readable status text, included in error messages for non-success codes.
    #[serde(default)]
    message: Option<String>,
    /// Handle used to poll a running statement and to fetch additional partitions.
    #[serde(default)]
    statement_handle: Option<String>,
}
 202
 203#[derive(Debug, Clone, Deserialize)]
 204#[serde(rename_all = "camelCase")]
 205struct SnowflakeResultSetMetaData {
 206    #[serde(default, rename = "rowType")]
 207    row_type: Vec<SnowflakeColumnMeta>,
 208    #[serde(default)]
 209    num_rows: Option<i64>,
 210    #[serde(default)]
 211    partition_info: Vec<SnowflakePartitionInfo>,
 212}
 213
/// Placeholder for one entry of `partitionInfo`. Its contents are ignored; callers only count
/// entries to learn how many partitions must be fetched.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
 217
/// One column description from `rowType`.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name (empty when absent), matched against the aliases selected in each query.
    #[serde(default)]
    name: String,
}
 223
 224fn examples_from_response(
 225    response: &SnowflakeStatementResponse,
 226    example_index: usize,
 227) -> Result<impl Iterator<Item = Example> + '_> {
 228    if let Some(code) = &response.code {
 229        if code != SNOWFLAKE_SUCCESS_CODE {
 230            anyhow::bail!(
 231                "snowflake sql api returned error code={code} message={}",
 232                response.message.as_deref().unwrap_or("<no message>")
 233            );
 234        }
 235    }
 236
 237    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
 238        let Some(example_value) = data_row.get(example_index) else {
 239            return None;
 240        };
 241        if example_value.is_null() {
 242            return None;
 243        }
 244
 245        let parse_result = match example_value {
 246            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
 247            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
 248        };
 249
 250        match parse_result {
 251            Ok(spec) => Some(Example {
 252                spec,
 253                prompt_inputs: None,
 254                prompt: None,
 255                predictions: Vec::new(),
 256                score: Vec::new(),
 257                qa: Vec::new(),
 258                state: None,
 259            }),
 260            Err(error) => {
 261                let raw_json = serde_json::to_string_pretty(example_value)
 262                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
 263                log::error!(
 264                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
 265                );
 266                None
 267            }
 268        }
 269    });
 270
 271    Ok(iter)
 272}
 273
 274async fn run_sql_with_polling(
 275    http_client: Arc<dyn HttpClient>,
 276    base_url: &str,
 277    token: &str,
 278    request: &serde_json::Value,
 279    step_progress: &crate::progress::StepProgress,
 280    background_executor: BackgroundExecutor,
 281) -> Result<SnowflakeStatementResponse> {
 282    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 283
 284    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 285        let statement_handle = response
 286            .statement_handle
 287            .as_ref()
 288            .context("async query response missing statementHandle")?
 289            .clone();
 290
 291        for attempt in 1..=MAX_POLL_ATTEMPTS {
 292            step_progress.set_substatus(format!("polling ({attempt})"));
 293
 294            background_executor.timer(POLL_INTERVAL).await;
 295
 296            response =
 297                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 298
 299            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 300                break;
 301            }
 302        }
 303
 304        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 305            anyhow::bail!(
 306                "query still running after {} poll attempts ({} seconds)",
 307                MAX_POLL_ATTEMPTS,
 308                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 309            );
 310        }
 311    }
 312
 313    Ok(response)
 314}
 315
 316async fn fetch_partition(
 317    http_client: Arc<dyn HttpClient>,
 318    base_url: &str,
 319    token: &str,
 320    statement_handle: &str,
 321    partition: usize,
 322) -> Result<SnowflakeStatementResponse> {
 323    let url = format!(
 324        "{}/api/v2/statements/{}?partition={}",
 325        base_url.trim_end_matches('/'),
 326        statement_handle,
 327        partition
 328    );
 329
 330    let http_request = Request::builder()
 331        .method(Method::GET)
 332        .uri(url.as_str())
 333        .header("Authorization", format!("Bearer {token}"))
 334        .header(
 335            "X-Snowflake-Authorization-Token-Type",
 336            "PROGRAMMATIC_ACCESS_TOKEN",
 337        )
 338        .header("Accept", "application/json")
 339        .header("Accept-Encoding", "gzip")
 340        .header("User-Agent", "edit_prediction_cli")
 341        .body(AsyncBody::empty())?;
 342
 343    let response = http_client
 344        .send(http_request)
 345        .await
 346        .context("failed to send partition request to Snowflake SQL API")?;
 347
 348    let status = response.status();
 349    let content_encoding = response
 350        .headers()
 351        .get("content-encoding")
 352        .and_then(|v| v.to_str().ok())
 353        .map(|s| s.to_lowercase());
 354
 355    let body_bytes = {
 356        use futures::AsyncReadExt as _;
 357
 358        let mut body = response.into_body();
 359        let mut bytes = Vec::new();
 360        body.read_to_end(&mut bytes)
 361            .await
 362            .context("failed to read Snowflake SQL API partition response body")?;
 363        bytes
 364    };
 365
 366    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 367        let mut decoder = GzDecoder::new(&body_bytes[..]);
 368        let mut decompressed = Vec::new();
 369        decoder
 370            .read_to_end(&mut decompressed)
 371            .context("failed to decompress gzip response")?;
 372        decompressed
 373    } else {
 374        body_bytes
 375    };
 376
 377    if !status.is_success() && status.as_u16() != 202 {
 378        let body_text = String::from_utf8_lossy(&body_bytes);
 379        anyhow::bail!(
 380            "snowflake sql api partition request http {}: {}",
 381            status.as_u16(),
 382            body_text
 383        );
 384    }
 385
 386    if body_bytes.is_empty() {
 387        anyhow::bail!(
 388            "snowflake sql api partition {} returned empty response body (http {})",
 389            partition,
 390            status.as_u16()
 391        );
 392    }
 393
 394    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 395        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 396        format!(
 397            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 398            partition,
 399            status.as_u16(),
 400            body_preview
 401        )
 402    })
 403}
 404
 405async fn run_sql(
 406    http_client: Arc<dyn HttpClient>,
 407    base_url: &str,
 408    token: &str,
 409    request: &serde_json::Value,
 410) -> Result<SnowflakeStatementResponse> {
 411    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 412
 413    let request_body =
 414        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 415
 416    let http_request = Request::builder()
 417        .method(Method::POST)
 418        .uri(url.as_str())
 419        .header("Authorization", format!("Bearer {token}"))
 420        .header(
 421            "X-Snowflake-Authorization-Token-Type",
 422            "PROGRAMMATIC_ACCESS_TOKEN",
 423        )
 424        .header("Content-Type", "application/json")
 425        .header("Accept", "application/json")
 426        .header("User-Agent", "edit_prediction_cli")
 427        .body(AsyncBody::from(request_body.clone()))?;
 428
 429    let response = http_client
 430        .send(http_request)
 431        .await
 432        .context("failed to send request to Snowflake SQL API")?;
 433
 434    let status = response.status();
 435    let body_bytes = {
 436        use futures::AsyncReadExt as _;
 437
 438        let mut body = response.into_body();
 439        let mut bytes = Vec::new();
 440        body.read_to_end(&mut bytes)
 441            .await
 442            .context("failed to read Snowflake SQL API response body")?;
 443        bytes
 444    };
 445
 446    if !status.is_success() && status.as_u16() != 202 {
 447        let body_text = String::from_utf8_lossy(&body_bytes);
 448        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 449    }
 450
 451    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 452        .context("failed to parse Snowflake SQL API response JSON")
 453}
 454
 455pub async fn fetch_rejected_examples_after(
 456    http_client: Arc<dyn HttpClient>,
 457    after_timestamps: &[String],
 458    max_rows_per_timestamp: usize,
 459    background_executor: BackgroundExecutor,
 460) -> Result<Vec<Example>> {
 461    if after_timestamps.is_empty() {
 462        return Ok(Vec::new());
 463    }
 464
 465    let progress = Progress::global();
 466
 467    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 468        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 469    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 470        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 471    )?;
 472    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 473
 474    let mut all_examples = Vec::new();
 475
 476    for after_date in after_timestamps.iter() {
 477        let step_progress_name = format!("rejected>{after_date}");
 478        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 479        step_progress.set_substatus("querying");
 480
 481        // Join rejected events with their corresponding request events to get the full context.
 482        // We filter for V3 sampling data which contains the structured input we need.
 483        // We also filter for predictions that were actually shown to the user (was_shown = true)
 484        // to focus on explicit user rejections rather than implicit cancellations.
 485        let statement = indoc! {r#"
 486            SELECT
 487                req.event_properties:request_id::string AS request_id,
 488                req.device_id::string AS device_id,
 489                req.time::string AS time,
 490                req.event_properties:input AS input,
 491                req.event_properties:prompt::string AS prompt,
 492                req.event_properties:output::string AS output,
 493                rej.event_properties:was_shown::boolean AS was_shown,
 494                rej.event_properties:reason::string AS reason
 495            FROM events req
 496            INNER JOIN events rej
 497                ON req.event_properties:request_id = rej.event_properties:request_id
 498            WHERE req.event_type = ?
 499                AND rej.event_type = ?
 500                AND req.event_properties:version = 'V3'
 501                AND rej.event_properties:was_shown = true
 502                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 503            ORDER BY req.time ASC
 504            LIMIT ?
 505        "#};
 506
 507        let request = json!({
 508            "statement": statement,
 509            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 510            "database": "EVENTS",
 511            "schema": "PUBLIC",
 512            "warehouse": "DBT",
 513            "role": role,
 514            "bindings": {
 515                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 516                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
 517                "3": { "type": "TEXT", "value": after_date },
 518                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 519            }
 520        });
 521
 522        let response = run_sql_with_polling(
 523            http_client.clone(),
 524            &base_url,
 525            &token,
 526            &request,
 527            &step_progress,
 528            background_executor.clone(),
 529        )
 530        .await?;
 531
 532        let total_rows = response
 533            .result_set_meta_data
 534            .as_ref()
 535            .and_then(|m| m.num_rows)
 536            .unwrap_or(response.data.len() as i64);
 537
 538        let num_partitions = response
 539            .result_set_meta_data
 540            .as_ref()
 541            .map(|m| m.partition_info.len())
 542            .unwrap_or(1)
 543            .max(1);
 544
 545        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 546        step_progress.set_substatus("parsing");
 547
 548        let column_indices = get_column_indices(
 549            &response.result_set_meta_data,
 550            &[
 551                "request_id",
 552                "device_id",
 553                "time",
 554                "input",
 555                "prompt",
 556                "output",
 557                "was_shown",
 558                "reason",
 559            ],
 560        );
 561
 562        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
 563
 564        if num_partitions > 1 {
 565            let statement_handle = response
 566                .statement_handle
 567                .as_ref()
 568                .context("response has multiple partitions but no statementHandle")?;
 569
 570            for partition in 1..num_partitions {
 571                step_progress.set_substatus(format!(
 572                    "fetching partition {}/{}",
 573                    partition + 1,
 574                    num_partitions
 575                ));
 576
 577                let partition_response = fetch_partition(
 578                    http_client.clone(),
 579                    &base_url,
 580                    &token,
 581                    statement_handle,
 582                    partition,
 583                )
 584                .await?;
 585
 586                all_examples.extend(rejected_examples_from_response(
 587                    &partition_response,
 588                    &column_indices,
 589                )?);
 590            }
 591        }
 592
 593        step_progress.set_substatus("done");
 594    }
 595
 596    Ok(all_examples)
 597}
 598
 599pub async fn fetch_requested_examples_after(
 600    http_client: Arc<dyn HttpClient>,
 601    after_timestamps: &[String],
 602    max_rows_per_timestamp: usize,
 603    background_executor: BackgroundExecutor,
 604) -> Result<Vec<Example>> {
 605    if after_timestamps.is_empty() {
 606        return Ok(Vec::new());
 607    }
 608
 609    let progress = Progress::global();
 610
 611    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 612        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 613    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 614        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 615    )?;
 616    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 617
 618    let mut all_examples = Vec::new();
 619
 620    for after_date in after_timestamps.iter() {
 621        let step_progress_name = format!("requested>{after_date}");
 622        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 623        step_progress.set_substatus("querying");
 624
 625        let statement = indoc! {r#"
 626            SELECT
 627                req.event_properties:request_id::string AS request_id,
 628                req.device_id::string AS device_id,
 629                req.time::string AS time,
 630                req.event_properties:input AS input
 631            FROM events req
 632            WHERE req.event_type = ?
 633                AND req.event_properties:version = 'V3'
 634                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 635            ORDER BY req.time ASC
 636            LIMIT ?
 637        "#};
 638
 639        let request = json!({
 640            "statement": statement,
 641            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 642            "database": "EVENTS",
 643            "schema": "PUBLIC",
 644            "warehouse": "DBT",
 645            "role": role,
 646            "bindings": {
 647                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 648                "2": { "type": "TEXT", "value": after_date },
 649                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 650            }
 651        });
 652
 653        let response = run_sql_with_polling(
 654            http_client.clone(),
 655            &base_url,
 656            &token,
 657            &request,
 658            &step_progress,
 659            background_executor.clone(),
 660        )
 661        .await?;
 662
 663        let total_rows = response
 664            .result_set_meta_data
 665            .as_ref()
 666            .and_then(|m| m.num_rows)
 667            .unwrap_or(response.data.len() as i64);
 668
 669        let num_partitions = response
 670            .result_set_meta_data
 671            .as_ref()
 672            .map(|m| m.partition_info.len())
 673            .unwrap_or(1)
 674            .max(1);
 675
 676        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 677        step_progress.set_substatus("parsing");
 678
 679        let column_indices = get_column_indices(
 680            &response.result_set_meta_data,
 681            &["request_id", "device_id", "time", "input"],
 682        );
 683
 684        all_examples.extend(requested_examples_from_response(
 685            &response,
 686            &column_indices,
 687        )?);
 688
 689        if num_partitions > 1 {
 690            let statement_handle = response
 691                .statement_handle
 692                .as_ref()
 693                .context("response has multiple partitions but no statementHandle")?;
 694
 695            for partition in 1..num_partitions {
 696                step_progress.set_substatus(format!(
 697                    "fetching partition {}/{}",
 698                    partition + 1,
 699                    num_partitions
 700                ));
 701
 702                let partition_response = fetch_partition(
 703                    http_client.clone(),
 704                    &base_url,
 705                    &token,
 706                    statement_handle,
 707                    partition,
 708                )
 709                .await?;
 710
 711                all_examples.extend(requested_examples_from_response(
 712                    &partition_response,
 713                    &column_indices,
 714                )?);
 715            }
 716        }
 717
 718        step_progress.set_substatus("done");
 719    }
 720
 721    Ok(all_examples)
 722}
 723
 724pub async fn fetch_rated_examples_after(
 725    http_client: Arc<dyn HttpClient>,
 726    inputs: &[(String, Option<EditPredictionRating>)],
 727    max_rows_per_timestamp: usize,
 728    background_executor: BackgroundExecutor,
 729) -> Result<Vec<Example>> {
 730    if inputs.is_empty() {
 731        return Ok(Vec::new());
 732    }
 733
 734    let progress = Progress::global();
 735
 736    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 737        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 738    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 739        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 740    )?;
 741    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 742
 743    let mut all_examples = Vec::new();
 744
 745    for (after_date, rating_filter) in inputs.iter() {
 746        let filter_label = match rating_filter {
 747            None => "",
 748            Some(EditPredictionRating::Positive) => ":positive",
 749            Some(EditPredictionRating::Negative) => ":negative",
 750        };
 751        let step_progress_name = format!("rated{filter_label}>{after_date}");
 752        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 753        step_progress.set_substatus("querying");
 754
 755        let rating_value = rating_filter.as_ref().map(|r| match r {
 756            EditPredictionRating::Positive => "Positive",
 757            EditPredictionRating::Negative => "Negative",
 758        });
 759
 760        let statement = indoc! {r#"
 761            SELECT
 762                event_properties:inputs AS inputs,
 763                event_properties:output::string AS output,
 764                event_properties:rating::string AS rating,
 765                event_properties:feedback::string AS feedback,
 766                device_id::string AS device_id,
 767                time::string AS time
 768            FROM events
 769            WHERE event_type = ?
 770                AND (? IS NULL OR event_properties:rating::string = ?)
 771                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 772                AND event_properties:inputs IS NOT NULL
 773                AND event_properties:inputs:cursor_excerpt IS NOT NULL
 774                AND event_properties:output IS NOT NULL
 775            ORDER BY time ASC
 776            LIMIT ?
 777        "#};
 778
 779        let bindings = json!({
 780            "1": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
 781            "2": { "type": "TEXT", "value": rating_value },
 782            "3": { "type": "TEXT", "value": rating_value },
 783            "4": { "type": "TEXT", "value": after_date },
 784            "5": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 785        });
 786
 787        let request = json!({
 788            "statement": statement,
 789            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 790            "database": "EVENTS",
 791            "schema": "PUBLIC",
 792            "warehouse": "DBT",
 793            "role": role,
 794            "bindings": bindings
 795        });
 796
 797        let response = run_sql_with_polling(
 798            http_client.clone(),
 799            &base_url,
 800            &token,
 801            &request,
 802            &step_progress,
 803            background_executor.clone(),
 804        )
 805        .await?;
 806
 807        let total_rows = response
 808            .result_set_meta_data
 809            .as_ref()
 810            .and_then(|m| m.num_rows)
 811            .unwrap_or(response.data.len() as i64);
 812
 813        let num_partitions = response
 814            .result_set_meta_data
 815            .as_ref()
 816            .map(|m| m.partition_info.len())
 817            .unwrap_or(1)
 818            .max(1);
 819
 820        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 821        step_progress.set_substatus("parsing");
 822
 823        let column_indices = get_column_indices(
 824            &response.result_set_meta_data,
 825            &[
 826                "inputs",
 827                "output",
 828                "rating",
 829                "feedback",
 830                "device_id",
 831                "time",
 832            ],
 833        );
 834
 835        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
 836
 837        if num_partitions > 1 {
 838            let statement_handle = response
 839                .statement_handle
 840                .as_ref()
 841                .context("response has multiple partitions but no statementHandle")?;
 842
 843            for partition in 1..num_partitions {
 844                step_progress.set_substatus(format!(
 845                    "fetching partition {}/{}",
 846                    partition + 1,
 847                    num_partitions
 848                ));
 849
 850                let partition_response = fetch_partition(
 851                    http_client.clone(),
 852                    &base_url,
 853                    &token,
 854                    statement_handle,
 855                    partition,
 856                )
 857                .await?;
 858
 859                all_examples.extend(rated_examples_from_response(
 860                    &partition_response,
 861                    &column_indices,
 862                )?);
 863            }
 864        }
 865
 866        step_progress.set_substatus("done");
 867    }
 868
 869    Ok(all_examples)
 870}
 871
 872fn rated_examples_from_response<'a>(
 873    response: &'a SnowflakeStatementResponse,
 874    column_indices: &'a std::collections::HashMap<String, usize>,
 875) -> Result<impl Iterator<Item = Example> + 'a> {
 876    if let Some(code) = &response.code {
 877        if code != SNOWFLAKE_SUCCESS_CODE {
 878            anyhow::bail!(
 879                "snowflake sql api returned error code={code} message={}",
 880                response.message.as_deref().unwrap_or("<no message>")
 881            );
 882        }
 883    }
 884
 885    let iter = response
 886        .data
 887        .iter()
 888        .enumerate()
 889        .filter_map(move |(row_index, data_row)| {
 890            let get_string = |name: &str| -> Option<String> {
 891                let index = column_indices.get(name).copied()?;
 892                match data_row.get(index)? {
 893                    JsonValue::String(s) => Some(s.clone()),
 894                    JsonValue::Null => None,
 895                    other => Some(other.to_string()),
 896                }
 897            };
 898
 899            let get_json = |name: &str| -> Option<JsonValue> {
 900                let index = column_indices.get(name).copied()?;
 901                let value = data_row.get(index)?;
 902                if value.is_null() {
 903                    return None;
 904                }
 905                match value {
 906                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 907                    other => Some(other.clone()),
 908                }
 909            };
 910
 911            let inputs_json = get_json("inputs");
 912            let inputs: Option<ZetaPromptInput> = match &inputs_json {
 913                Some(v) => match serde_json::from_value(v.clone()) {
 914                    Ok(parsed) => Some(parsed),
 915                    Err(e) => {
 916                        log::warn!(
 917                            "skipping row {row_index}: failed to parse inputs - {e}",
 918                        );
 919                        return None;
 920                    }
 921                },
 922                None => None,
 923            };
 924            let output = get_string("output");
 925            let rating = get_string("rating");
 926            let feedback = get_string("feedback").unwrap_or_default();
 927            let device_id = get_string("device_id");
 928            let time = get_string("time");
 929
 930            match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
 931                (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
 932                    Some(build_rated_example(
 933                        device_id,
 934                        time,
 935                        inputs,
 936                        output,
 937                        rating,
 938                        feedback,
 939                    ))
 940                }
 941                _ => {
 942                    log::warn!(
 943                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
 944                        inputs_json.is_some(),
 945                        output.is_some(),
 946                        rating.is_some(),
 947                        device_id.is_some(),
 948                        time.is_some(),
 949                    );
 950                    None
 951                }
 952            }
 953        });
 954
 955    Ok(iter)
 956}
 957
 958fn build_rated_example(
 959    device_id: String,
 960    time: String,
 961    input: ZetaPromptInput,
 962    output: String,
 963    rating: String,
 964    feedback: String,
 965) -> Example {
 966    let parsed_rating = if rating == "Positive" {
 967        EditPredictionRating::Positive
 968    } else {
 969        EditPredictionRating::Negative
 970    };
 971    let is_positive = parsed_rating == EditPredictionRating::Positive;
 972    let request_id = format!("rated-{}-{}", device_id, time);
 973
 974    let tags = if is_positive {
 975        vec!["rated:positive".to_string()]
 976    } else {
 977        vec!["rated:negative".to_string()]
 978    };
 979
 980    let mut example = build_example_from_snowflake(request_id, device_id, time, input, tags, None);
 981
 982    example.spec.rating = Some(parsed_rating);
 983
 984    if !feedback.is_empty() {
 985        example
 986            .spec
 987            .human_feedback
 988            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
 989    }
 990
 991    if is_positive {
 992        example.spec.expected_patches = vec![output];
 993    } else {
 994        example.spec.rejected_patch = Some(output);
 995    }
 996
 997    example
 998}
 999
1000fn requested_examples_from_response<'a>(
1001    response: &'a SnowflakeStatementResponse,
1002    column_indices: &'a std::collections::HashMap<String, usize>,
1003) -> Result<impl Iterator<Item = Example> + 'a> {
1004    if let Some(code) = &response.code {
1005        if code != SNOWFLAKE_SUCCESS_CODE {
1006            anyhow::bail!(
1007                "snowflake sql api returned error code={code} message={}",
1008                response.message.as_deref().unwrap_or("<no message>")
1009            );
1010        }
1011    }
1012
1013    let iter = response
1014        .data
1015        .iter()
1016        .enumerate()
1017        .filter_map(move |(row_index, data_row)| {
1018            let get_string = |name: &str| -> Option<String> {
1019                let index = column_indices.get(name).copied()?;
1020                match data_row.get(index)? {
1021                    JsonValue::String(s) => Some(s.clone()),
1022                    JsonValue::Null => None,
1023                    other => Some(other.to_string()),
1024                }
1025            };
1026
1027            let get_json = |name: &str| -> Option<JsonValue> {
1028                let index = column_indices.get(name).copied()?;
1029                let value = data_row.get(index)?;
1030                if value.is_null() {
1031                    return None;
1032                }
1033                match value {
1034                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1035                    other => Some(other.clone()),
1036                }
1037            };
1038
1039            let request_id_str = get_string("request_id");
1040            let device_id = get_string("device_id");
1041            let time = get_string("time");
1042            let input_json = get_json("input");
1043            let input: Option<ZetaPromptInput> =
1044                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1045
1046            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1047                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1048                    Some(build_example_from_snowflake(
1049                        request_id,
1050                        device_id,
1051                        time,
1052                        input,
1053                        vec!["requested".to_string()],
1054                        None,
1055                    ))
1056                }
1057                _ => {
1058                    log::warn!(
1059                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1060                        request_id_str.is_some(),
1061                        device_id.is_some(),
1062                        time.is_some(),
1063                        input_json.is_some(),
1064                    );
1065                    None
1066                }
1067            }
1068        });
1069
1070    Ok(iter)
1071}
1072
1073fn rejected_examples_from_response<'a>(
1074    response: &'a SnowflakeStatementResponse,
1075    column_indices: &'a std::collections::HashMap<String, usize>,
1076) -> Result<impl Iterator<Item = Example> + 'a> {
1077    if let Some(code) = &response.code {
1078        if code != SNOWFLAKE_SUCCESS_CODE {
1079            anyhow::bail!(
1080                "snowflake sql api returned error code={code} message={}",
1081                response.message.as_deref().unwrap_or("<no message>")
1082            );
1083        }
1084    }
1085
1086    let iter = response
1087        .data
1088        .iter()
1089        .enumerate()
1090        .filter_map(move |(row_index, data_row)| {
1091            let get_string = |name: &str| -> Option<String> {
1092                let index = column_indices.get(name).copied()?;
1093                match data_row.get(index)? {
1094                    JsonValue::String(s) => Some(s.clone()),
1095                    JsonValue::Null => None,
1096                    other => Some(other.to_string()),
1097                }
1098            };
1099
1100            let get_json = |name: &str| -> Option<JsonValue> {
1101                let index = column_indices.get(name).copied()?;
1102                let value = data_row.get(index)?;
1103                if value.is_null() {
1104                    return None;
1105                }
1106                match value {
1107                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1108                    other => Some(other.clone()),
1109                }
1110            };
1111
1112            let get_bool = |name: &str| -> Option<bool> {
1113                let index = column_indices.get(name).copied()?;
1114                match data_row.get(index)? {
1115                    JsonValue::Bool(b) => Some(*b),
1116                    JsonValue::String(s) => s.parse().ok(),
1117                    _ => None,
1118                }
1119            };
1120
1121            let request_id_str = get_string("request_id");
1122            let device_id = get_string("device_id");
1123            let time = get_string("time");
1124            let input_json = get_json("input");
1125            let input: Option<ZetaPromptInput> =
1126                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1127            let output = get_string("output");
1128            let was_shown = get_bool("was_shown");
1129            let reason = get_string("reason");
1130
1131            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1132                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1133                    Some(build_rejected_example(
1134                        request_id,
1135                        device_id,
1136                        time,
1137                        input,
1138                        output,
1139                        was_shown,
1140                        reason,
1141                    ))
1142                }
1143                _ => {
1144                    log::warn!(
1145                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1146                        request_id_str.is_some(),
1147                        device_id.is_some(),
1148                        time.is_some(),
1149                        input_json.is_some(),
1150                        output.is_some(),
1151                        was_shown.is_some(),
1152                        reason.is_some()
1153                    );
1154                    None
1155                }
1156            }
1157        });
1158
1159    Ok(iter)
1160}
1161
1162fn build_rejected_example(
1163    request_id: String,
1164    device_id: String,
1165    time: String,
1166    input: ZetaPromptInput,
1167    output: String,
1168    was_shown: bool,
1169    reason: String,
1170) -> Example {
1171    let rejected_patch = build_output_patch(
1172        &input.cursor_path,
1173        input.cursor_excerpt.as_ref(),
1174        &input.editable_range_in_excerpt,
1175        &output,
1176    );
1177    let mut example = build_example_from_snowflake(
1178        request_id,
1179        device_id,
1180        time,
1181        input,
1182        vec![format!("rejection:{}", reason.to_lowercase())],
1183        Some(RejectionInfo { reason, was_shown }),
1184    );
1185    example.spec.rejected_patch = Some(rejected_patch);
1186    example
1187}
1188
/// Rejection details that `build_example_from_snowflake` copies into an example's
/// telemetry source.
#[derive(Debug, Clone)]
struct RejectionInfo {
    /// Rejection reason as reported by telemetry; it is also used, lowercased, as a tag.
    reason: String,
    /// Whether the rejected prediction was shown to the user.
    was_shown: bool,
}
1193
1194fn build_example_from_snowflake(
1195    request_id: String,
1196    device_id: String,
1197    time: String,
1198    input: ZetaPromptInput,
1199    tags: Vec<String>,
1200    rejection: Option<RejectionInfo>,
1201) -> Example {
1202    let events: Vec<CapturedEvent> = input
1203        .events
1204        .iter()
1205        .map(|event| match event.as_ref() {
1206            zeta_prompt::Event::BufferChange {
1207                path,
1208                old_path,
1209                diff,
1210                predicted,
1211                in_open_source_repo,
1212            } => CapturedEvent {
1213                path: path.clone(),
1214                old_path: old_path.clone(),
1215                diff: diff.clone(),
1216                predicted: *predicted,
1217                in_open_source_repo: *in_open_source_repo,
1218            },
1219        })
1220        .collect();
1221
1222    let related_files: Vec<CapturedRelatedFile> = input
1223        .related_files
1224        .iter()
1225        .map(|rf| CapturedRelatedFile {
1226            path: rf.path.clone(),
1227            max_row: rf.max_row,
1228            excerpts: rf
1229                .excerpts
1230                .iter()
1231                .map(|e| CapturedRelatedExcerpt {
1232                    row_range: e.row_range.clone(),
1233                    text: e.text.to_string(),
1234                })
1235                .collect(),
1236        })
1237        .collect();
1238
1239    let cursor_excerpt = input.cursor_excerpt.as_ref();
1240    let cursor_offset = input.cursor_offset_in_excerpt;
1241
1242    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1243
1244    let mut edit_history = String::new();
1245    for event in &input.events {
1246        zeta_prompt::write_event(&mut edit_history, event);
1247        edit_history.push('\n');
1248    }
1249
1250    let (rejection_reason, was_shown) = match &rejection {
1251        Some(r) => (r.reason.clone(), r.was_shown),
1252        None => (String::new(), false),
1253    };
1254
1255    let spec = ExampleSpec {
1256        name: request_id.clone(),
1257        repository_url: String::new(),
1258        revision: String::new(),
1259        tags,
1260        reasoning: None,
1261        uncommitted_diff: String::new(),
1262        cursor_path: input.cursor_path.clone(),
1263        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1264        edit_history,
1265        expected_patches: Vec::new(),
1266        rejected_patch: None,
1267        captured_prompt_input: Some(CapturedPromptInput {
1268            cursor_file_content: cursor_excerpt.to_string(),
1269            cursor_offset,
1270            cursor_row,
1271            cursor_column,
1272            excerpt_start_row: None,
1273            events,
1274            related_files,
1275        }),
1276        telemetry: Some(TelemetrySource {
1277            request_id,
1278            device_id,
1279            time,
1280            rejection_reason,
1281            was_shown,
1282        }),
1283        human_feedback: Vec::new(),
1284        rating: None,
1285    };
1286
1287    Example {
1288        spec,
1289        prompt_inputs: None,
1290        prompt: None,
1291        predictions: Vec::new(),
1292        score: Vec::new(),
1293        qa: Vec::new(),
1294        state: None,
1295    }
1296}
1297
/// Returns the zero-based `(row, column)` of byte `offset` within `text`.
///
/// The row is the number of `'\n'` characters strictly before `offset`. The column is the
/// byte distance from the start of that line to `offset`. If `offset` lies past the end of
/// `text`, every newline is counted and the column extends past the final line.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // Scanning raw bytes is equivalent to scanning chars here, because the byte 0x0A never
    // appears inside a multi-byte UTF-8 sequence.
    let prefix = &text.as_bytes()[..offset.min(text.len())];
    let row = prefix.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = prefix
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, (offset - line_start) as u32)
}
1313
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte `cursor_offset`.
///
/// `cursor_offset` comes from external telemetry, so it is sanitized rather than trusted:
/// - an offset past the end is clamped to the end of `excerpt`;
/// - an offset that falls inside a multi-byte character is moved back to the start of that
///   character.
///
/// Without this, slicing at such an offset would panic and abort the whole pull.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split = cursor_offset.min(excerpt.len());
    // Terminates because index 0 is always a char boundary.
    while !excerpt.is_char_boundary(split) {
        split -= 1;
    }
    let (before, after) = excerpt.split_at(split);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
1319
1320fn build_output_patch(
1321    cursor_path: &std::path::Path,
1322    cursor_excerpt: &str,
1323    editable_range: &std::ops::Range<usize>,
1324    model_output: &str,
1325) -> String {
1326    let old_text = &cursor_excerpt[editable_range.clone()];
1327
1328    let editable_start_row = cursor_excerpt[..editable_range.start]
1329        .chars()
1330        .filter(|&c| c == '\n')
1331        .count() as u32;
1332
1333    let diff_body = language::unified_diff_with_offsets(
1334        old_text,
1335        model_output,
1336        editable_start_row,
1337        editable_start_row,
1338    );
1339
1340    let mut patch = String::new();
1341    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1342    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1343    patch.push_str(&diff_body);
1344    patch
1345}
1346
1347fn get_column_indices(
1348    meta: &Option<SnowflakeResultSetMetaData>,
1349    names: &[&str],
1350) -> std::collections::HashMap<String, usize> {
1351    let mut indices = std::collections::HashMap::new();
1352    if let Some(meta) = meta {
1353        for (index, col) in meta.row_type.iter().enumerate() {
1354            for &name in names {
1355                if col.name.eq_ignore_ascii_case(name) {
1356                    indices.insert(name.to_string(), index);
1357                }
1358            }
1359        }
1360    }
1361    indices
1362}