pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::fmt::Write as _;
   9use std::io::Read;
  10use std::sync::Arc;
  11use std::time::Duration;
  12use telemetry_events::EditPredictionRating;
  13
  14use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
  15
  16use crate::example::Example;
  17use crate::progress::{InfoStyle, Progress, Step};
  18const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
  19use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
  20
  21pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
  22pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
  23const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
  24const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
  25const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
  26const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
  27
  28/// Minimum Zed version for filtering captured examples.
  29/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
  30/// where `zed_version >= 0.224.1`.
  31#[derive(Clone, Copy, Debug)]
  32pub struct MinCaptureVersion {
  33    pub minor: u32,
  34    pub patch: u32,
  35}
  36
  37const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
  38const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
  39pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
  40pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
  41
  42/// Parse an input token of the form `captured-after:{timestamp}`.
  43pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  44    input.strip_prefix("captured-after:")
  45}
  46
  47/// Parse an input token of the form `rejected-after:{timestamp}`.
  48pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  49    input.strip_prefix("rejected-after:")
  50}
  51
  52/// Parse an input token of the form `requested-after:{timestamp}`.
  53pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  54    input.strip_prefix("requested-after:")
  55}
  56
  57/// Parse an input token of the form `settled-after:{timestamp}`.
  58pub fn parse_settled_after_input(input: &str) -> Option<&str> {
  59    input.strip_prefix("settled-after:")
  60}
  61
  62/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  63/// or `rated-negative-after:{timestamp}`.
  64/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  65pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  66    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  67        Some((timestamp, Some(EditPredictionRating::Positive)))
  68    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  69        Some((timestamp, Some(EditPredictionRating::Negative)))
  70    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  71        Some((timestamp, None))
  72    } else {
  73        None
  74    }
  75}
  76
  77#[derive(Debug, Clone, Deserialize)]
  78#[serde(rename_all = "camelCase")]
  79pub(crate) struct SnowflakeStatementResponse {
  80    #[serde(default)]
  81    pub(crate) data: Vec<Vec<JsonValue>>,
  82    #[serde(default)]
  83    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
  84    #[serde(default)]
  85    pub(crate) code: Option<String>,
  86    #[serde(default)]
  87    pub(crate) message: Option<String>,
  88    #[serde(default)]
  89    pub(crate) statement_handle: Option<String>,
  90}
  91
  92#[derive(Debug, Clone, Deserialize)]
  93#[serde(rename_all = "camelCase")]
  94pub(crate) struct SnowflakeResultSetMetaData {
  95    #[serde(default, rename = "rowType")]
  96    row_type: Vec<SnowflakeColumnMeta>,
  97    #[serde(default)]
  98    num_rows: Option<i64>,
  99    #[serde(default)]
 100    partition_info: Vec<SnowflakePartitionInfo>,
 101}
 102
 103#[derive(Debug, Clone, Deserialize)]
 104#[serde(rename_all = "camelCase")]
 105struct SnowflakePartitionInfo {}
 106
 107#[derive(Debug, Clone, Deserialize)]
 108struct SnowflakeColumnMeta {
 109    #[serde(default)]
 110    name: String,
 111}
 112
 113async fn run_sql_with_polling(
 114    http_client: Arc<dyn HttpClient>,
 115    base_url: &str,
 116    token: &str,
 117    request: &serde_json::Value,
 118    step_progress: &crate::progress::StepProgress,
 119    background_executor: BackgroundExecutor,
 120) -> Result<SnowflakeStatementResponse> {
 121    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 122
 123    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 124        let statement_handle = response
 125            .statement_handle
 126            .as_ref()
 127            .context("async query response missing statementHandle")?
 128            .clone();
 129
 130        for attempt in 1..=MAX_POLL_ATTEMPTS {
 131            step_progress.set_substatus(format!("polling ({attempt})"));
 132
 133            background_executor.timer(POLL_INTERVAL).await;
 134
 135            response =
 136                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 137
 138            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 139                break;
 140            }
 141        }
 142
 143        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 144            anyhow::bail!(
 145                "query still running after {} poll attempts ({} seconds)",
 146                MAX_POLL_ATTEMPTS,
 147                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 148            );
 149        }
 150    }
 151
 152    Ok(response)
 153}
 154
 155pub(crate) async fn fetch_partition(
 156    http_client: Arc<dyn HttpClient>,
 157    base_url: &str,
 158    token: &str,
 159    statement_handle: &str,
 160    partition: usize,
 161) -> Result<SnowflakeStatementResponse> {
 162    let url = format!(
 163        "{}/api/v2/statements/{}?partition={}",
 164        base_url.trim_end_matches('/'),
 165        statement_handle,
 166        partition
 167    );
 168
 169    let http_request = Request::builder()
 170        .method(Method::GET)
 171        .uri(url.as_str())
 172        .header("Authorization", format!("Bearer {token}"))
 173        .header(
 174            "X-Snowflake-Authorization-Token-Type",
 175            "PROGRAMMATIC_ACCESS_TOKEN",
 176        )
 177        .header("Accept", "application/json")
 178        .header("Accept-Encoding", "gzip")
 179        .header("User-Agent", "edit_prediction_cli")
 180        .body(AsyncBody::empty())?;
 181
 182    let response = http_client
 183        .send(http_request)
 184        .await
 185        .context("failed to send partition request to Snowflake SQL API")?;
 186
 187    let status = response.status();
 188    let content_encoding = response
 189        .headers()
 190        .get("content-encoding")
 191        .and_then(|v| v.to_str().ok())
 192        .map(|s| s.to_lowercase());
 193
 194    let body_bytes = {
 195        use futures::AsyncReadExt as _;
 196
 197        let mut body = response.into_body();
 198        let mut bytes = Vec::new();
 199        body.read_to_end(&mut bytes)
 200            .await
 201            .context("failed to read Snowflake SQL API partition response body")?;
 202        bytes
 203    };
 204
 205    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 206        let mut decoder = GzDecoder::new(&body_bytes[..]);
 207        let mut decompressed = Vec::new();
 208        decoder
 209            .read_to_end(&mut decompressed)
 210            .context("failed to decompress gzip response")?;
 211        decompressed
 212    } else {
 213        body_bytes
 214    };
 215
 216    if !status.is_success() && status.as_u16() != 202 {
 217        let body_text = String::from_utf8_lossy(&body_bytes);
 218        anyhow::bail!(
 219            "snowflake sql api partition request http {}: {}",
 220            status.as_u16(),
 221            body_text
 222        );
 223    }
 224
 225    if body_bytes.is_empty() {
 226        anyhow::bail!(
 227            "snowflake sql api partition {} returned empty response body (http {})",
 228            partition,
 229            status.as_u16()
 230        );
 231    }
 232
 233    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 234        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 235        format!(
 236            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 237            partition,
 238            status.as_u16(),
 239            body_preview
 240        )
 241    })
 242}
 243
 244pub(crate) async fn run_sql(
 245    http_client: Arc<dyn HttpClient>,
 246    base_url: &str,
 247    token: &str,
 248    request: &serde_json::Value,
 249) -> Result<SnowflakeStatementResponse> {
 250    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 251
 252    let request_body =
 253        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 254
 255    let http_request = Request::builder()
 256        .method(Method::POST)
 257        .uri(url.as_str())
 258        .header("Authorization", format!("Bearer {token}"))
 259        .header(
 260            "X-Snowflake-Authorization-Token-Type",
 261            "PROGRAMMATIC_ACCESS_TOKEN",
 262        )
 263        .header("Content-Type", "application/json")
 264        .header("Accept", "application/json")
 265        .header("User-Agent", "edit_prediction_cli")
 266        .body(AsyncBody::from(request_body.clone()))?;
 267
 268    let response = http_client
 269        .send(http_request)
 270        .await
 271        .context("failed to send request to Snowflake SQL API")?;
 272
 273    let status = response.status();
 274    let body_bytes = {
 275        use futures::AsyncReadExt as _;
 276
 277        let mut body = response.into_body();
 278        let mut bytes = Vec::new();
 279        body.read_to_end(&mut bytes)
 280            .await
 281            .context("failed to read Snowflake SQL API response body")?;
 282        bytes
 283    };
 284
 285    if !status.is_success() && status.as_u16() != 202 {
 286        let body_text = String::from_utf8_lossy(&body_bytes);
 287        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 288    }
 289
 290    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 291        .context("failed to parse Snowflake SQL API response JSON")
 292}
 293
 294pub async fn fetch_rejected_examples_after(
 295    http_client: Arc<dyn HttpClient>,
 296    after_timestamps: &[String],
 297    max_rows_per_timestamp: usize,
 298    offset: usize,
 299    background_executor: BackgroundExecutor,
 300    min_capture_version: Option<MinCaptureVersion>,
 301) -> Result<Vec<Example>> {
 302    if after_timestamps.is_empty() {
 303        return Ok(Vec::new());
 304    }
 305
 306    let progress = Progress::global();
 307
 308    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 309        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 310    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 311        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 312    )?;
 313    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 314
 315    let mut all_examples = Vec::new();
 316
 317    for after_date in after_timestamps.iter() {
 318        let step_progress_name = format!("rejected>{after_date}");
 319        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 320        step_progress.set_substatus("querying");
 321
 322        // Join rejected events with their corresponding request events to get the full context.
 323        // We filter for V3 sampling data which contains the structured input we need.
 324        // We also filter for predictions that were actually shown to the user (was_shown = true)
 325        // to focus on explicit user rejections rather than implicit cancellations.
 326        let statement = indoc! {r#"
 327            SELECT
 328                req.event_properties:request_id::string AS request_id,
 329                req.device_id::string AS device_id,
 330                req.time::string AS time,
 331                req.event_properties:input AS input,
 332                req.event_properties:prompt::string AS prompt,
 333                req.event_properties:output::string AS output,
 334                rej.event_properties:was_shown::boolean AS was_shown,
 335                rej.event_properties:reason::string AS reason,
 336                req.event_properties:zed_version::string AS zed_version
 337            FROM events req
 338            INNER JOIN events rej
 339                ON req.event_properties:request_id = rej.event_properties:request_id
 340            WHERE req.event_type = ?
 341                AND rej.event_type = ?
 342                AND req.event_properties:version = 'V3'
 343                AND rej.event_properties:was_shown = true
 344                AND req.event_properties:input:can_collect_data = true
 345                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 346                AND (? IS NULL OR (
 347                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 348                    OR (
 349                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 350                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 351                    )
 352                ))
 353            ORDER BY req.time ASC
 354            LIMIT ?
 355            OFFSET ?
 356        "#};
 357
 358        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
 359        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
 360        let min_minor_str_ref = min_minor_str.as_deref();
 361        let min_patch_str_ref = min_patch_str.as_deref();
 362        let request = json!({
 363            "statement": statement,
 364            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 365            "database": "EVENTS",
 366            "schema": "PUBLIC",
 367            "warehouse": "DBT",
 368            "role": role,
 369            "bindings": {
 370                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 371                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
 372                "3": { "type": "TEXT", "value": after_date },
 373                "4": { "type": "FIXED", "value": min_minor_str_ref },
 374                "5": { "type": "FIXED", "value": min_minor_str_ref },
 375                "6": { "type": "FIXED", "value": min_minor_str_ref },
 376                "7": { "type": "FIXED", "value": min_patch_str_ref },
 377                "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 378                "9": { "type": "FIXED", "value": offset.to_string() }
 379            }
 380        });
 381
 382        let response = run_sql_with_polling(
 383            http_client.clone(),
 384            &base_url,
 385            &token,
 386            &request,
 387            &step_progress,
 388            background_executor.clone(),
 389        )
 390        .await?;
 391
 392        let total_rows = response
 393            .result_set_meta_data
 394            .as_ref()
 395            .and_then(|m| m.num_rows)
 396            .unwrap_or(response.data.len() as i64);
 397
 398        let num_partitions = response
 399            .result_set_meta_data
 400            .as_ref()
 401            .map(|m| m.partition_info.len())
 402            .unwrap_or(1)
 403            .max(1);
 404
 405        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 406        step_progress.set_substatus("parsing");
 407
 408        let column_indices = get_column_indices(
 409            &response.result_set_meta_data,
 410            &[
 411                "request_id",
 412                "device_id",
 413                "time",
 414                "input",
 415                "prompt",
 416                "output",
 417                "was_shown",
 418                "reason",
 419                "zed_version",
 420            ],
 421        );
 422
 423        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
 424
 425        if num_partitions > 1 {
 426            let statement_handle = response
 427                .statement_handle
 428                .as_ref()
 429                .context("response has multiple partitions but no statementHandle")?;
 430
 431            for partition in 1..num_partitions {
 432                step_progress.set_substatus(format!(
 433                    "fetching partition {}/{}",
 434                    partition + 1,
 435                    num_partitions
 436                ));
 437
 438                let partition_response = fetch_partition(
 439                    http_client.clone(),
 440                    &base_url,
 441                    &token,
 442                    statement_handle,
 443                    partition,
 444                )
 445                .await?;
 446
 447                all_examples.extend(rejected_examples_from_response(
 448                    &partition_response,
 449                    &column_indices,
 450                )?);
 451            }
 452        }
 453
 454        step_progress.set_substatus("done");
 455    }
 456
 457    Ok(all_examples)
 458}
 459
 460pub async fn fetch_requested_examples_after(
 461    http_client: Arc<dyn HttpClient>,
 462    after_timestamps: &[String],
 463    max_rows_per_timestamp: usize,
 464    offset: usize,
 465    background_executor: BackgroundExecutor,
 466    min_capture_version: Option<MinCaptureVersion>,
 467) -> Result<Vec<Example>> {
 468    if after_timestamps.is_empty() {
 469        return Ok(Vec::new());
 470    }
 471
 472    let progress = Progress::global();
 473
 474    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 475        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 476    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 477        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 478    )?;
 479    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 480
 481    let mut all_examples = Vec::new();
 482
 483    for after_date in after_timestamps.iter() {
 484        let step_progress_name = format!("requested>{after_date}");
 485        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 486        step_progress.set_substatus("querying");
 487
 488        let statement = indoc! {r#"
 489            SELECT
 490                req.event_properties:request_id::string AS request_id,
 491                req.device_id::string AS device_id,
 492                req.time::string AS time,
 493                req.event_properties:input AS input,
 494                req.event_properties:zed_version::string AS zed_version
 495            FROM events req
 496            WHERE req.event_type = ?
 497                AND req.event_properties:version = 'V3'
 498                AND req.event_properties:input:can_collect_data = true
 499                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 500                AND (? IS NULL OR (
 501                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 502                    OR (
 503                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 504                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 505                    )
 506                ))
 507            ORDER BY req.time ASC
 508            LIMIT ?
 509            OFFSET ?
 510        "#};
 511
 512        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
 513        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
 514        let min_minor_str_ref = min_minor_str.as_deref();
 515        let min_patch_str_ref = min_patch_str.as_deref();
 516        let request = json!({
 517            "statement": statement,
 518            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 519            "database": "EVENTS",
 520            "schema": "PUBLIC",
 521            "warehouse": "DBT",
 522            "role": role,
 523            "bindings": {
 524                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 525                "2": { "type": "TEXT", "value": after_date },
 526                "3": { "type": "FIXED", "value": min_minor_str_ref },
 527                "4": { "type": "FIXED", "value": min_minor_str_ref },
 528                "5": { "type": "FIXED", "value": min_minor_str_ref },
 529                "6": { "type": "FIXED", "value": min_patch_str_ref },
 530                "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 531                "8": { "type": "FIXED", "value": offset.to_string() }
 532            }
 533        });
 534
 535        let response = run_sql_with_polling(
 536            http_client.clone(),
 537            &base_url,
 538            &token,
 539            &request,
 540            &step_progress,
 541            background_executor.clone(),
 542        )
 543        .await?;
 544
 545        let total_rows = response
 546            .result_set_meta_data
 547            .as_ref()
 548            .and_then(|m| m.num_rows)
 549            .unwrap_or(response.data.len() as i64);
 550
 551        let num_partitions = response
 552            .result_set_meta_data
 553            .as_ref()
 554            .map(|m| m.partition_info.len())
 555            .unwrap_or(1)
 556            .max(1);
 557
 558        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 559        step_progress.set_substatus("parsing");
 560
 561        let column_indices = get_column_indices(
 562            &response.result_set_meta_data,
 563            &["request_id", "device_id", "time", "input", "zed_version"],
 564        );
 565
 566        all_examples.extend(requested_examples_from_response(
 567            &response,
 568            &column_indices,
 569        )?);
 570
 571        if num_partitions > 1 {
 572            let statement_handle = response
 573                .statement_handle
 574                .as_ref()
 575                .context("response has multiple partitions but no statementHandle")?;
 576
 577            for partition in 1..num_partitions {
 578                step_progress.set_substatus(format!(
 579                    "fetching partition {}/{}",
 580                    partition + 1,
 581                    num_partitions
 582                ));
 583
 584                let partition_response = fetch_partition(
 585                    http_client.clone(),
 586                    &base_url,
 587                    &token,
 588                    statement_handle,
 589                    partition,
 590                )
 591                .await?;
 592
 593                all_examples.extend(requested_examples_from_response(
 594                    &partition_response,
 595                    &column_indices,
 596                )?);
 597            }
 598        }
 599
 600        step_progress.set_substatus("done");
 601    }
 602
 603    Ok(all_examples)
 604}
 605
 606pub async fn fetch_settled_examples_after(
 607    http_client: Arc<dyn HttpClient>,
 608    after_timestamps: &[String],
 609    max_rows_per_timestamp: usize,
 610    offset: usize,
 611    background_executor: BackgroundExecutor,
 612    min_capture_version: Option<MinCaptureVersion>,
 613) -> Result<Vec<Example>> {
 614    if after_timestamps.is_empty() {
 615        return Ok(Vec::new());
 616    }
 617
 618    let progress = Progress::global();
 619
 620    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 621        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 622    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 623        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 624    )?;
 625    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 626
 627    let mut all_examples = Vec::new();
 628
 629    for after_date in after_timestamps.iter() {
 630        let step_progress_name = format!("settled>{after_date}");
 631        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 632        step_progress.set_substatus("querying");
 633
 634        let statement = indoc! {r#"
 635            WITH requested AS (
 636                SELECT
 637                    req.event_properties:request_id::string AS request_id,
 638                    req.device_id::string AS device_id,
 639                    req.time AS req_time,
 640                    req.time::string AS time,
 641                    req.event_properties:input AS input,
 642                    req.event_properties:format::string AS requested_format,
 643                    req.event_properties:output::string AS requested_output,
 644                    req.event_properties:zed_version::string AS zed_version
 645                FROM events req
 646                WHERE req.event_type = ?
 647                    AND req.event_properties:version = 'V3'
 648                    AND req.event_properties:input:can_collect_data = true
 649                    AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 650            )
 651            SELECT
 652                req.request_id AS request_id,
 653                req.device_id AS device_id,
 654                req.time AS time,
 655                req.input AS input,
 656                req.requested_output AS requested_output,
 657                settled.event_properties:settled_editable_region::string AS settled_editable_region,
 658                req.requested_format AS requested_format,
 659                req.zed_version AS zed_version
 660            FROM requested req
 661            INNER JOIN events settled
 662                ON req.request_id = settled.event_properties:request_id::string
 663            WHERE settled.event_type = ?
 664            ORDER BY req.req_time ASC
 665            LIMIT ?
 666            OFFSET ?
 667        "#};
 668
 669        let _ = min_capture_version;
 670        let request = json!({
 671            "statement": statement,
 672            "timeout": SETTLED_STATEMENT_TIMEOUT_SECONDS,
 673            "database": "EVENTS",
 674            "schema": "PUBLIC",
 675            "warehouse": "DBT",
 676            "role": role,
 677            "bindings": {
 678                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 679                "2": { "type": "TEXT", "value": after_date },
 680                "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
 681                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 682                "5": { "type": "FIXED", "value": offset.to_string() }
 683            }
 684        });
 685
 686        let response = run_sql_with_polling(
 687            http_client.clone(),
 688            &base_url,
 689            &token,
 690            &request,
 691            &step_progress,
 692            background_executor.clone(),
 693        )
 694        .await?;
 695
 696        let total_rows = response
 697            .result_set_meta_data
 698            .as_ref()
 699            .and_then(|m| m.num_rows)
 700            .unwrap_or(response.data.len() as i64);
 701
 702        let num_partitions = response
 703            .result_set_meta_data
 704            .as_ref()
 705            .map(|m| m.partition_info.len())
 706            .unwrap_or(1)
 707            .max(1);
 708
 709        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 710        step_progress.set_substatus("parsing");
 711
 712        let column_indices = get_column_indices(
 713            &response.result_set_meta_data,
 714            &[
 715                "request_id",
 716                "device_id",
 717                "time",
 718                "input",
 719                "requested_output",
 720                "settled_editable_region",
 721                "requested_format",
 722                "zed_version",
 723            ],
 724        );
 725
 726        all_examples.extend(settled_examples_from_response(&response, &column_indices)?);
 727
 728        if num_partitions > 1 {
 729            let statement_handle = response
 730                .statement_handle
 731                .as_ref()
 732                .context("response has multiple partitions but no statementHandle")?;
 733
 734            for partition in 1..num_partitions {
 735                step_progress.set_substatus(format!(
 736                    "fetching partition {}/{}",
 737                    partition + 1,
 738                    num_partitions
 739                ));
 740
 741                let partition_response = fetch_partition(
 742                    http_client.clone(),
 743                    &base_url,
 744                    &token,
 745                    statement_handle,
 746                    partition,
 747                )
 748                .await?;
 749
 750                all_examples.extend(settled_examples_from_response(
 751                    &partition_response,
 752                    &column_indices,
 753                )?);
 754            }
 755        }
 756
 757        step_progress.set_substatus("done");
 758    }
 759
 760    Ok(all_examples)
 761}
 762
 763pub async fn fetch_rated_examples_after(
 764    http_client: Arc<dyn HttpClient>,
 765    inputs: &[(String, Option<EditPredictionRating>)],
 766    max_rows_per_timestamp: usize,
 767    offset: usize,
 768    background_executor: BackgroundExecutor,
 769    _min_capture_version: Option<MinCaptureVersion>,
 770) -> Result<Vec<Example>> {
 771    if inputs.is_empty() {
 772        return Ok(Vec::new());
 773    }
 774
 775    let progress = Progress::global();
 776
 777    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 778        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 779    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 780        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 781    )?;
 782    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 783
 784    let mut all_examples = Vec::new();
 785
 786    for (after_date, rating_filter) in inputs.iter() {
 787        let filter_label = match rating_filter {
 788            None => "",
 789            Some(EditPredictionRating::Positive) => ":positive",
 790            Some(EditPredictionRating::Negative) => ":negative",
 791        };
 792        let step_progress_name = format!("rated{filter_label}>{after_date}");
 793        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 794        step_progress.set_substatus("querying");
 795
 796        let rating_value = rating_filter.as_ref().map(|r| match r {
 797            EditPredictionRating::Positive => "Positive",
 798            EditPredictionRating::Negative => "Negative",
 799        });
 800
 801        let statement = indoc! {r#"
 802            SELECT
 803                rated.event_properties:request_id::string AS request_id,
 804                rated.event_properties:inputs AS inputs,
 805                rated.event_properties:output::string AS output,
 806                rated.event_properties:rating::string AS rating,
 807                rated.event_properties:feedback::string AS feedback,
 808                rated.device_id::string AS device_id,
 809                rated.time::string AS time,
 810                deploy.event_properties:experiment_name::string AS experiment_name,
 811                deploy.event_properties:environment::string AS environment,
 812                rated.event_properties:zed_version::string AS zed_version
 813            FROM events rated
 814            LEFT JOIN events req
 815                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
 816                AND req.event_type = ?
 817            LEFT JOIN events deploy
 818                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
 819                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
 820                AND deploy.event_type = ?
 821            WHERE rated.event_type = ?
 822                AND (? IS NULL OR rated.event_properties:rating::string = ?)
 823                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
 824                AND rated.event_properties:inputs IS NOT NULL
 825                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
 826                AND rated.event_properties:output IS NOT NULL
 827                AND rated.event_properties:can_collect_data = true
 828            ORDER BY rated.time ASC
 829            LIMIT ?
 830            OFFSET ?
 831        "#};
 832
 833        let bindings = json!({
 834            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 835            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
 836            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
 837            "4": { "type": "TEXT", "value": rating_value },
 838            "5": { "type": "TEXT", "value": rating_value },
 839            "6": { "type": "TEXT", "value": after_date },
 840            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 841            "8": { "type": "FIXED", "value": offset.to_string() }
 842        });
 843
 844        let request = json!({
 845            "statement": statement,
 846            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 847            "database": "EVENTS",
 848            "schema": "PUBLIC",
 849            "warehouse": "DBT",
 850            "role": role,
 851            "bindings": bindings
 852        });
 853
 854        let response = run_sql_with_polling(
 855            http_client.clone(),
 856            &base_url,
 857            &token,
 858            &request,
 859            &step_progress,
 860            background_executor.clone(),
 861        )
 862        .await?;
 863
 864        let total_rows = response
 865            .result_set_meta_data
 866            .as_ref()
 867            .and_then(|m| m.num_rows)
 868            .unwrap_or(response.data.len() as i64);
 869
 870        let num_partitions = response
 871            .result_set_meta_data
 872            .as_ref()
 873            .map(|m| m.partition_info.len())
 874            .unwrap_or(1)
 875            .max(1);
 876
 877        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 878        step_progress.set_substatus("parsing");
 879
 880        let column_indices = get_column_indices(
 881            &response.result_set_meta_data,
 882            &[
 883                "request_id",
 884                "inputs",
 885                "output",
 886                "rating",
 887                "feedback",
 888                "device_id",
 889                "time",
 890                "experiment_name",
 891                "environment",
 892                "zed_version",
 893            ],
 894        );
 895
 896        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
 897
 898        if num_partitions > 1 {
 899            let statement_handle = response
 900                .statement_handle
 901                .as_ref()
 902                .context("response has multiple partitions but no statementHandle")?;
 903
 904            for partition in 1..num_partitions {
 905                step_progress.set_substatus(format!(
 906                    "fetching partition {}/{}",
 907                    partition + 1,
 908                    num_partitions
 909                ));
 910
 911                let partition_response = fetch_partition(
 912                    http_client.clone(),
 913                    &base_url,
 914                    &token,
 915                    statement_handle,
 916                    partition,
 917                )
 918                .await?;
 919
 920                all_examples.extend(rated_examples_from_response(
 921                    &partition_response,
 922                    &column_indices,
 923                )?);
 924            }
 925        }
 926
 927        step_progress.set_substatus("done");
 928    }
 929
 930    Ok(all_examples)
 931}
 932
 933fn rated_examples_from_response<'a>(
 934    response: &'a SnowflakeStatementResponse,
 935    column_indices: &'a std::collections::HashMap<String, usize>,
 936) -> Result<impl Iterator<Item = Example> + 'a> {
 937    if let Some(code) = &response.code {
 938        if code != SNOWFLAKE_SUCCESS_CODE {
 939            anyhow::bail!(
 940                "snowflake sql api returned error code={code} message={}",
 941                response.message.as_deref().unwrap_or("<no message>")
 942            );
 943        }
 944    }
 945
 946    let iter = response
 947        .data
 948        .iter()
 949        .enumerate()
 950        .filter_map(move |(row_index, data_row)| {
 951            let get_string = |name: &str| -> Option<String> {
 952                let index = column_indices.get(name).copied()?;
 953                match data_row.get(index)? {
 954                    JsonValue::String(s) => Some(s.clone()),
 955                    JsonValue::Null => None,
 956                    other => Some(other.to_string()),
 957                }
 958            };
 959
 960            let get_json = |name: &str| -> Option<JsonValue> {
 961                let index = column_indices.get(name).copied()?;
 962                let value = data_row.get(index)?;
 963                if value.is_null() {
 964                    return None;
 965                }
 966                match value {
 967                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 968                    other => Some(other.clone()),
 969                }
 970            };
 971
 972            let request_id = get_string("request_id");
 973            let inputs_json = get_json("inputs");
 974            let inputs: Option<ZetaPromptInput> = match &inputs_json {
 975                Some(v) => match serde_json::from_value(v.clone()) {
 976                    Ok(parsed) => Some(parsed),
 977                    Err(e) => {
 978                        log::warn!(
 979                            "skipping row {row_index}: failed to parse inputs - {e}",
 980                        );
 981                        return None;
 982                    }
 983                },
 984                None => None,
 985            };
 986            let output = get_string("output");
 987            let rating = get_string("rating");
 988            let feedback = get_string("feedback").unwrap_or_default();
 989            let device_id = get_string("device_id");
 990            let time = get_string("time");
 991            let experiment_name = get_string("experiment_name");
 992            let environment = get_string("environment");
 993            let zed_version = get_string("zed_version");
 994
 995            match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
 996                (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
 997                    Some(build_rated_example(
 998                        request_id,
 999                        device_id,
1000                        time,
1001                        inputs,
1002                        output,
1003                        rating,
1004                        feedback,
1005                        experiment_name,
1006                        environment,
1007                        zed_version,
1008                    ))
1009                }
1010                _ => {
1011                    log::warn!(
1012                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
1013                        inputs_json.is_some(),
1014                        output.is_some(),
1015                        rating.is_some(),
1016                        device_id.is_some(),
1017                        time.is_some(),
1018                    );
1019                    None
1020                }
1021            }
1022        });
1023
1024    Ok(iter)
1025}
1026
1027fn build_rated_example(
1028    request_id: Option<String>,
1029    device_id: String,
1030    time: String,
1031    input: ZetaPromptInput,
1032    output: String,
1033    rating: String,
1034    feedback: String,
1035    experiment_name: Option<String>,
1036    environment: Option<String>,
1037    zed_version: Option<String>,
1038) -> Example {
1039    let parsed_rating = if rating == "Positive" {
1040        EditPredictionRating::Positive
1041    } else {
1042        EditPredictionRating::Negative
1043    };
1044    let is_positive = parsed_rating == EditPredictionRating::Positive;
1045    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1046
1047    let mut tags = Vec::with_capacity(3);
1048    tags.push(if is_positive {
1049        "rated:positive".to_string()
1050    } else {
1051        "rated:negative".to_string()
1052    });
1053    if let Some(experiment) = experiment_name {
1054        tags.push(format!("experiment:{experiment}"));
1055    }
1056    if let Some(env) = environment {
1057        tags.push(format!("environment:{env}"));
1058    }
1059
1060    let mut example =
1061        build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1062
1063    example.spec.rating = Some(parsed_rating);
1064
1065    if !feedback.is_empty() {
1066        example
1067            .spec
1068            .human_feedback
1069            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1070    }
1071
1072    if is_positive {
1073        example.spec.expected_patches = vec![output];
1074    } else {
1075        example.spec.rejected_patch = Some(output);
1076    }
1077
1078    example
1079}
1080
1081fn requested_examples_from_response<'a>(
1082    response: &'a SnowflakeStatementResponse,
1083    column_indices: &'a std::collections::HashMap<String, usize>,
1084) -> Result<impl Iterator<Item = Example> + 'a> {
1085    if let Some(code) = &response.code {
1086        if code != SNOWFLAKE_SUCCESS_CODE {
1087            anyhow::bail!(
1088                "snowflake sql api returned error code={code} message={}",
1089                response.message.as_deref().unwrap_or("<no message>")
1090            );
1091        }
1092    }
1093
1094    let iter = response
1095        .data
1096        .iter()
1097        .enumerate()
1098        .filter_map(move |(row_index, data_row)| {
1099            let get_string = |name: &str| -> Option<String> {
1100                let index = column_indices.get(name).copied()?;
1101                match data_row.get(index)? {
1102                    JsonValue::String(s) => Some(s.clone()),
1103                    JsonValue::Null => None,
1104                    other => Some(other.to_string()),
1105                }
1106            };
1107
1108            let get_json = |name: &str| -> Option<JsonValue> {
1109                let index = column_indices.get(name).copied()?;
1110                let value = data_row.get(index)?;
1111                if value.is_null() {
1112                    return None;
1113                }
1114                match value {
1115                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1116                    other => Some(other.clone()),
1117                }
1118            };
1119
1120            let request_id_str = get_string("request_id");
1121            let device_id = get_string("device_id");
1122            let time = get_string("time");
1123            let input_json = get_json("input");
1124            let input: Option<ZetaPromptInput> =
1125                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1126            let zed_version = get_string("zed_version");
1127
1128            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1129                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1130                    Some(build_example_from_snowflake(
1131                        request_id,
1132                        device_id,
1133                        time,
1134                        input,
1135                        vec!["requested".to_string()],
1136                        None,
1137                        zed_version,
1138                    ))
1139                }
1140                _ => {
1141                    log::warn!(
1142                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1143                        request_id_str.is_some(),
1144                        device_id.is_some(),
1145                        time.is_some(),
1146                        input_json.is_some(),
1147                    );
1148                    None
1149                }
1150            }
1151        });
1152
1153    Ok(iter)
1154}
1155
1156fn settled_examples_from_response<'a>(
1157    response: &'a SnowflakeStatementResponse,
1158    column_indices: &'a std::collections::HashMap<String, usize>,
1159) -> Result<impl Iterator<Item = Example> + 'a> {
1160    if let Some(code) = &response.code {
1161        if code != SNOWFLAKE_SUCCESS_CODE {
1162            anyhow::bail!(
1163                "snowflake sql api returned error code={code} message={}",
1164                response.message.as_deref().unwrap_or("<no message>")
1165            );
1166        }
1167    }
1168
1169    let iter = response
1170        .data
1171        .iter()
1172        .enumerate()
1173        .filter_map(move |(row_index, data_row)| {
1174            let get_value = |name: &str| -> Option<JsonValue> {
1175                let index = column_indices.get(name).copied()?;
1176                let value = data_row.get(index)?;
1177                if value.is_null() {
1178                    None
1179                } else {
1180                    Some(value.clone())
1181                }
1182            };
1183
1184            let get_string = |name: &str| -> Option<String> {
1185                match get_value(name)? {
1186                    JsonValue::String(s) => Some(s),
1187                    other => Some(other.to_string()),
1188                }
1189            };
1190
1191            let parse_json_value = |_: &str, raw: Option<&JsonValue>| -> Option<JsonValue> {
1192                let value = raw?;
1193                match value {
1194                    JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1195                    other => Some(other.clone()),
1196                }
1197            };
1198
1199            let request_id_str = get_string("request_id");
1200            let device_id = get_string("device_id");
1201            let time = get_string("time");
1202            let input_raw = get_value("input");
1203            let input_json = parse_json_value("input", input_raw.as_ref());
1204            let input: Option<ZetaPromptInput> = input_json
1205                .as_ref()
1206                .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1207            let requested_output = get_string("requested_output");
1208            let settled_editable_region = get_string("settled_editable_region");
1209            let requested_format =
1210                get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1211            let zed_version = get_string("zed_version");
1212
1213            match (
1214                request_id_str.clone(),
1215                device_id.clone(),
1216                time.clone(),
1217                input.clone(),
1218                requested_output.clone(),
1219                settled_editable_region.clone(),
1220                requested_format,
1221            ) {
1222                (
1223                    Some(request_id),
1224                    Some(device_id),
1225                    Some(time),
1226                    Some(input),
1227                    Some(requested_output),
1228                    Some(settled_editable_region),
1229                    Some(requested_format),
1230                ) => Some(build_settled_example(
1231                    request_id,
1232                    device_id,
1233                    time,
1234                    input,
1235                    requested_output,
1236                    settled_editable_region,
1237                    requested_format,
1238                    zed_version,
1239                )),
1240                _ => {
1241                    let mut missing_fields = Vec::new();
1242
1243                    if request_id_str.is_none() {
1244                        missing_fields.push("request_id");
1245                    }
1246                    if device_id.is_none() {
1247                        missing_fields.push("device_id");
1248                    }
1249                    if time.is_none() {
1250                        missing_fields.push("time");
1251                    }
1252                    if input_raw.is_none() || input_json.is_none() || input.is_none() {
1253                        missing_fields.push("input");
1254                    }
1255                    if requested_output.is_none() {
1256                        missing_fields.push("requested_output");
1257                    }
1258                    if settled_editable_region.is_none() {
1259                        missing_fields.push("settled_editable_region");
1260                    }
1261                    if requested_format.is_none() {
1262                        missing_fields.push("requested_format");
1263                    }
1264
1265                    log::warn!(
1266                        "skipping settled row {row_index}: [{}]",
1267                        missing_fields.join(", "),
1268                    );
1269                    None
1270                }
1271            }
1272        });
1273
1274    Ok(iter)
1275}
1276
1277fn build_settled_example(
1278    request_id: String,
1279    device_id: String,
1280    time: String,
1281    input: ZetaPromptInput,
1282    requested_output: String,
1283    settled_editable_region: String,
1284    requested_format: ZetaFormat,
1285    zed_version: Option<String>,
1286) -> Example {
1287    let requested_editable_range = input
1288        .excerpt_ranges
1289        .as_ref()
1290        .map(|ranges| excerpt_range_for_format(requested_format, ranges).0)
1291        .unwrap_or_else(|| input.editable_range_in_excerpt.clone());
1292
1293    let base_cursor_excerpt = input.cursor_excerpt.to_string();
1294
1295    let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1296        && requested_editable_range.end <= base_cursor_excerpt.len();
1297    let mut example = build_example_from_snowflake(
1298        request_id.clone(),
1299        device_id,
1300        time,
1301        input,
1302        vec!["settled".to_string()],
1303        None,
1304        zed_version,
1305    );
1306
1307    if !requested_range_is_valid {
1308        log::warn!(
1309            "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1310            request_id,
1311            requested_editable_range,
1312            base_cursor_excerpt.len(),
1313        );
1314        return example;
1315    }
1316
1317    let settled_replacement = settled_editable_region.as_str();
1318    let rejected_patch = build_output_patch(
1319        &example.spec.cursor_path,
1320        &base_cursor_excerpt,
1321        &requested_editable_range,
1322        &requested_output,
1323    );
1324    let expected_patch = build_output_patch(
1325        &example.spec.cursor_path,
1326        &base_cursor_excerpt,
1327        &requested_editable_range,
1328        settled_replacement,
1329    );
1330
1331    example.spec.expected_patches = vec![expected_patch];
1332    example.spec.rejected_patch = Some(rejected_patch);
1333    example
1334}
1335
1336fn rejected_examples_from_response<'a>(
1337    response: &'a SnowflakeStatementResponse,
1338    column_indices: &'a std::collections::HashMap<String, usize>,
1339) -> Result<impl Iterator<Item = Example> + 'a> {
1340    if let Some(code) = &response.code {
1341        if code != SNOWFLAKE_SUCCESS_CODE {
1342            anyhow::bail!(
1343                "snowflake sql api returned error code={code} message={}",
1344                response.message.as_deref().unwrap_or("<no message>")
1345            );
1346        }
1347    }
1348
1349    let iter = response
1350        .data
1351        .iter()
1352        .enumerate()
1353        .filter_map(move |(row_index, data_row)| {
1354            let get_string = |name: &str| -> Option<String> {
1355                let index = column_indices.get(name).copied()?;
1356                match data_row.get(index)? {
1357                    JsonValue::String(s) => Some(s.clone()),
1358                    JsonValue::Null => None,
1359                    other => Some(other.to_string()),
1360                }
1361            };
1362
1363            let get_json = |name: &str| -> Option<JsonValue> {
1364                let index = column_indices.get(name).copied()?;
1365                let value = data_row.get(index)?;
1366                if value.is_null() {
1367                    return None;
1368                }
1369                match value {
1370                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1371                    other => Some(other.clone()),
1372                }
1373            };
1374
1375            let get_bool = |name: &str| -> Option<bool> {
1376                let index = column_indices.get(name).copied()?;
1377                match data_row.get(index)? {
1378                    JsonValue::Bool(b) => Some(*b),
1379                    JsonValue::String(s) => s.parse().ok(),
1380                    _ => None,
1381                }
1382            };
1383
1384            let request_id_str = get_string("request_id");
1385            let device_id = get_string("device_id");
1386            let time = get_string("time");
1387            let input_json = get_json("input");
1388            let input: Option<ZetaPromptInput> =
1389                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1390            let output = get_string("output");
1391            let was_shown = get_bool("was_shown");
1392            let reason = get_string("reason");
1393            let zed_version = get_string("zed_version");
1394
1395            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1396                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1397                    Some(build_rejected_example(
1398                        request_id,
1399                        device_id,
1400                        time,
1401                        input,
1402                        output,
1403                        was_shown,
1404                        reason,
1405                        zed_version,
1406                    ))
1407                }
1408                _ => {
1409                    log::warn!(
1410                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1411                        request_id_str.is_some(),
1412                        device_id.is_some(),
1413                        time.is_some(),
1414                        input_json.is_some(),
1415                        output.is_some(),
1416                        was_shown.is_some(),
1417                        reason.is_some()
1418                    );
1419                    None
1420                }
1421            }
1422        });
1423
1424    Ok(iter)
1425}
1426
1427fn build_rejected_example(
1428    request_id: String,
1429    device_id: String,
1430    time: String,
1431    input: ZetaPromptInput,
1432    output: String,
1433    was_shown: bool,
1434    reason: String,
1435    zed_version: Option<String>,
1436) -> Example {
1437    let rejected_patch = build_output_patch(
1438        &input.cursor_path,
1439        input.cursor_excerpt.as_ref(),
1440        &input.editable_range_in_excerpt,
1441        &output,
1442    );
1443    let mut example = build_example_from_snowflake(
1444        request_id,
1445        device_id,
1446        time,
1447        input,
1448        vec![format!("rejection:{}", reason.to_lowercase())],
1449        Some(RejectionInfo { reason, was_shown }),
1450        zed_version,
1451    );
1452    example.spec.rejected_patch = Some(rejected_patch);
1453    example
1454}
1455
/// Rejection details passed to `build_example_from_snowflake` for rejected predictions.
struct RejectionInfo {
    /// Rejection reason; copied into the example's `TelemetrySource::rejection_reason`.
    reason: String,
    /// Copied into the example's `TelemetrySource::was_shown`.
    /// NOTE(review): presumably whether the prediction was displayed to the user — confirm.
    was_shown: bool,
}
1460
1461fn build_example_from_snowflake(
1462    request_id: String,
1463    device_id: String,
1464    time: String,
1465    input: ZetaPromptInput,
1466    tags: Vec<String>,
1467    rejection: Option<RejectionInfo>,
1468    zed_version: Option<String>,
1469) -> Example {
1470    let cursor_excerpt = input.cursor_excerpt.as_ref();
1471    let cursor_offset = input.cursor_offset_in_excerpt;
1472
1473    let mut edit_history = String::new();
1474    for event in &input.events {
1475        zeta_prompt::write_event(&mut edit_history, event);
1476        edit_history.push('\n');
1477    }
1478
1479    let (rejection_reason, was_shown) = match &rejection {
1480        Some(r) => (r.reason.clone(), r.was_shown),
1481        None => (String::new(), false),
1482    };
1483
1484    let spec = ExampleSpec {
1485        name: request_id.clone(),
1486        repository_url: String::new(),
1487        revision: String::new(),
1488        tags,
1489        reasoning: None,
1490        uncommitted_diff: String::new(),
1491        cursor_path: input.cursor_path.clone(),
1492        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1493        edit_history,
1494        expected_patches: Vec::new(),
1495        rejected_patch: None,
1496        telemetry: Some(TelemetrySource {
1497            request_id,
1498            device_id,
1499            time,
1500            rejection_reason,
1501            was_shown,
1502        }),
1503        human_feedback: Vec::new(),
1504        rating: None,
1505    };
1506
1507    Example {
1508        spec,
1509        zed_version,
1510        prompt_inputs: Some(input),
1511        prompt: None,
1512        predictions: Vec::new(),
1513        score: Vec::new(),
1514        qa: Vec::new(),
1515        state: None,
1516    }
1517}
1518
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. Offsets past the end of `excerpt` are clamped to its
/// length, and an offset that falls inside a multi-byte character is moved back to the
/// start of that character so the split never panics.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1524
1525fn build_output_patch(
1526    cursor_path: &std::path::Path,
1527    cursor_excerpt: &str,
1528    editable_range: &std::ops::Range<usize>,
1529    model_output: &str,
1530) -> String {
1531    let old_text = &cursor_excerpt[editable_range.clone()];
1532
1533    let editable_start_row = cursor_excerpt[..editable_range.start]
1534        .chars()
1535        .filter(|&c| c == '\n')
1536        .count() as u32;
1537
1538    let diff_body = language::unified_diff_with_offsets(
1539        old_text,
1540        model_output,
1541        editable_start_row,
1542        editable_start_row,
1543    );
1544
1545    let mut patch = String::new();
1546    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1547    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1548    patch.push_str(&diff_body);
1549    patch
1550}
1551
1552pub(crate) fn get_column_indices(
1553    meta: &Option<SnowflakeResultSetMetaData>,
1554    names: &[&str],
1555) -> std::collections::HashMap<String, usize> {
1556    let mut indices = std::collections::HashMap::new();
1557    if let Some(meta) = meta {
1558        for (index, col) in meta.row_type.iter().enumerate() {
1559            for &name in names {
1560                if col.name.eq_ignore_ascii_case(name) {
1561                    indices.insert(name.to_string(), index);
1562                }
1563            }
1564        }
1565    }
1566    indices
1567}