pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::fmt::Write as _;
   9use std::io::Read;
  10use std::sync::Arc;
  11use std::time::Duration;
  12use telemetry_events::EditPredictionRating;
  13
  14use zeta_prompt::{ZetaFormat, ZetaPromptInput, excerpt_range_for_format};
  15
  16use crate::example::Example;
  17use crate::progress::{InfoStyle, Progress, Step};
  18const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
  19use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
  20
  21pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
  22pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
  23const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
  24const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
  25const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
  26const EDIT_PREDICTION_SETTLED_EVENT: &str = "Edit Prediction Settled";
  27
/// Minimum Zed version for filtering captured examples.
/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
/// where `zed_version >= 0.224.1`.
///
/// Only the minor and patch components are compared. The queries split `zed_version` on `.`,
/// read the second and third parts, and ignore any `+…` suffix on the third. The major
/// version is therefore implicitly assumed to be `0`.
#[derive(Clone, Copy, Debug)]
pub struct MinCaptureVersion {
    /// Minimum minor version (the `224` in `0.224.1`). Any strictly newer minor passes.
    pub minor: u32,
    /// Minimum patch version, applied only when the minor version equals `minor`.
    pub patch: u32,
}
  36
/// Value sent as the `timeout` field of the SQL API request for most queries in this file.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Value sent as the `timeout` field of the SQL API request for the settled-examples query.
const SETTLED_STATEMENT_TIMEOUT_SECONDS: u64 = 240;
/// Delay before each poll of a statement that Snowflake reports as still in progress.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls before giving up. With the interval above this is
/// 120 × 2s = 240s, matching the statement timeouts.
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
  41
  42/// Parse an input token of the form `captured-after:{timestamp}`.
  43pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  44    input.strip_prefix("captured-after:")
  45}
  46
  47/// Parse an input token of the form `rejected-after:{timestamp}`.
  48pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  49    input.strip_prefix("rejected-after:")
  50}
  51
  52/// Parse an input token of the form `requested-after:{timestamp}`.
  53pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  54    input.strip_prefix("requested-after:")
  55}
  56
  57/// Parse an input token of the form `settled-after:{timestamp}`.
  58pub fn parse_settled_after_input(input: &str) -> Option<&str> {
  59    input.strip_prefix("settled-after:")
  60}
  61
  62/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  63/// or `rated-negative-after:{timestamp}`.
  64/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  65pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  66    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  67        Some((timestamp, Some(EditPredictionRating::Positive)))
  68    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  69        Some((timestamp, Some(EditPredictionRating::Negative)))
  70    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  71        Some((timestamp, None))
  72    } else {
  73        None
  74    }
  75}
  76
/// Body of a Snowflake SQL API statement response. It is parsed both from the initial
/// `POST /api/v2/statements` and from `GET …?partition=N`. Every field defaults when absent,
/// so partial responses (e.g. an in-progress status) still deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Rows contained in this response; each row is a list of column values.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// `resultSetMetaData`: column and partition metadata. NOTE(review): callers read it only
    /// from the first response, presumably because later partitions omit it. Confirm against
    /// the SQL API docs.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Snowflake status code, compared against `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE` when polling.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Human-readable status message accompanying `code`.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// `statementHandle`: used to poll a running statement and to fetch further partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
  91
  92#[derive(Debug, Clone, Deserialize)]
  93#[serde(rename_all = "camelCase")]
  94pub(crate) struct SnowflakeResultSetMetaData {
  95    #[serde(default, rename = "rowType")]
  96    row_type: Vec<SnowflakeColumnMeta>,
  97    #[serde(default)]
  98    num_rows: Option<i64>,
  99    #[serde(default)]
 100    partition_info: Vec<SnowflakePartitionInfo>,
 101}
 102
 103#[derive(Debug, Clone, Deserialize)]
 104#[serde(rename_all = "camelCase")]
 105struct SnowflakePartitionInfo {}
 106
/// One entry of `rowType`: describes a single result column.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name as reported by Snowflake. Defaults to an empty string when missing.
    /// NOTE(review): presumably matched against the `AS` aliases by `get_column_indices`;
    /// verify how case differences are handled there.
    #[serde(default)]
    name: String,
}
 112
 113async fn run_sql_with_polling(
 114    http_client: Arc<dyn HttpClient>,
 115    base_url: &str,
 116    token: &str,
 117    request: &serde_json::Value,
 118    step_progress: &crate::progress::StepProgress,
 119    background_executor: BackgroundExecutor,
 120) -> Result<SnowflakeStatementResponse> {
 121    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 122
 123    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 124        let statement_handle = response
 125            .statement_handle
 126            .as_ref()
 127            .context("async query response missing statementHandle")?
 128            .clone();
 129
 130        for attempt in 1..=MAX_POLL_ATTEMPTS {
 131            step_progress.set_substatus(format!("polling ({attempt})"));
 132
 133            background_executor.timer(POLL_INTERVAL).await;
 134
 135            response =
 136                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 137
 138            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 139                break;
 140            }
 141        }
 142
 143        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 144            anyhow::bail!(
 145                "query still running after {} poll attempts ({} seconds)",
 146                MAX_POLL_ATTEMPTS,
 147                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 148            );
 149        }
 150    }
 151
 152    Ok(response)
 153}
 154
/// Connection settings for the Snowflake SQL API. `fetch_examples_with_query` fills them from
/// the `EP_SNOWFLAKE_*` environment variables.
struct SnowflakeConfig {
    /// Programmatic access token, sent as `Authorization: Bearer …`.
    token: String,
    /// Account URL, e.g. `https://<account>.snowflakecomputing.com`. A trailing `/` is tolerated.
    base_url: String,
    /// Optional role to run statements as; sent as JSON `null` when unset.
    role: Option<String>,
}
 160
 161async fn fetch_examples_with_query(
 162    http_client: Arc<dyn HttpClient>,
 163    step_progress: &crate::progress::StepProgress,
 164    background_executor: BackgroundExecutor,
 165    statement: &str,
 166    bindings: JsonValue,
 167    timeout_seconds: u64,
 168    required_columns: &[&str],
 169    parse_response: for<'a> fn(
 170        &'a SnowflakeStatementResponse,
 171        &'a std::collections::HashMap<String, usize>,
 172    ) -> Result<Box<dyn Iterator<Item = Example> + 'a>>,
 173) -> Result<Vec<Example>> {
 174    let snowflake = SnowflakeConfig {
 175        token: std::env::var("EP_SNOWFLAKE_API_KEY")
 176            .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?,
 177        base_url: std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 178            "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 179        )?,
 180        role: std::env::var("EP_SNOWFLAKE_ROLE").ok(),
 181    };
 182    let request = json!({
 183        "statement": statement,
 184        "timeout": timeout_seconds,
 185        "database": "EVENTS",
 186        "schema": "PUBLIC",
 187        "warehouse": "DBT",
 188        "role": snowflake.role.as_deref(),
 189        "bindings": bindings
 190    });
 191
 192    let response = run_sql_with_polling(
 193        http_client.clone(),
 194        &snowflake.base_url,
 195        &snowflake.token,
 196        &request,
 197        step_progress,
 198        background_executor,
 199    )
 200    .await?;
 201
 202    let total_rows = response
 203        .result_set_meta_data
 204        .as_ref()
 205        .and_then(|meta| meta.num_rows)
 206        .unwrap_or(response.data.len() as i64);
 207    let partition_count = response
 208        .result_set_meta_data
 209        .as_ref()
 210        .map(|meta| meta.partition_info.len())
 211        .unwrap_or(1)
 212        .max(1);
 213
 214    step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 215    step_progress.set_substatus("parsing");
 216
 217    let column_indices = get_column_indices(&response.result_set_meta_data, required_columns);
 218
 219    let mut parsed_examples = Vec::with_capacity(total_rows as usize);
 220    parsed_examples.extend(parse_response(&response, &column_indices)?);
 221
 222    if partition_count > 1 {
 223        let statement_handle = response
 224            .statement_handle
 225            .as_ref()
 226            .context("response has multiple partitions but no statementHandle")?;
 227
 228        for partition in 1..partition_count {
 229            step_progress.set_substatus(format!(
 230                "fetching partition {}/{}",
 231                partition + 1,
 232                partition_count
 233            ));
 234
 235            let partition_response = fetch_partition(
 236                http_client.clone(),
 237                &snowflake.base_url,
 238                &snowflake.token,
 239                statement_handle,
 240                partition,
 241            )
 242            .await?;
 243
 244            parsed_examples.extend(parse_response(&partition_response, &column_indices)?);
 245        }
 246    }
 247
 248    step_progress.set_substatus("done");
 249    Ok(parsed_examples)
 250}
 251
 252pub(crate) async fn fetch_partition(
 253    http_client: Arc<dyn HttpClient>,
 254    base_url: &str,
 255    token: &str,
 256    statement_handle: &str,
 257    partition: usize,
 258) -> Result<SnowflakeStatementResponse> {
 259    let url = format!(
 260        "{}/api/v2/statements/{}?partition={}",
 261        base_url.trim_end_matches('/'),
 262        statement_handle,
 263        partition
 264    );
 265
 266    let http_request = Request::builder()
 267        .method(Method::GET)
 268        .uri(url.as_str())
 269        .header("Authorization", format!("Bearer {token}"))
 270        .header(
 271            "X-Snowflake-Authorization-Token-Type",
 272            "PROGRAMMATIC_ACCESS_TOKEN",
 273        )
 274        .header("Accept", "application/json")
 275        .header("Accept-Encoding", "gzip")
 276        .header("User-Agent", "edit_prediction_cli")
 277        .body(AsyncBody::empty())?;
 278
 279    let response = http_client
 280        .send(http_request)
 281        .await
 282        .context("failed to send partition request to Snowflake SQL API")?;
 283
 284    let status = response.status();
 285    let content_encoding = response
 286        .headers()
 287        .get("content-encoding")
 288        .and_then(|v| v.to_str().ok())
 289        .map(|s| s.to_lowercase());
 290
 291    let body_bytes = {
 292        use futures::AsyncReadExt as _;
 293
 294        let mut body = response.into_body();
 295        let mut bytes = Vec::new();
 296        body.read_to_end(&mut bytes)
 297            .await
 298            .context("failed to read Snowflake SQL API partition response body")?;
 299        bytes
 300    };
 301
 302    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 303        let mut decoder = GzDecoder::new(&body_bytes[..]);
 304        let mut decompressed = Vec::new();
 305        decoder
 306            .read_to_end(&mut decompressed)
 307            .context("failed to decompress gzip response")?;
 308        decompressed
 309    } else {
 310        body_bytes
 311    };
 312
 313    if !status.is_success() && status.as_u16() != 202 {
 314        let body_text = String::from_utf8_lossy(&body_bytes);
 315        anyhow::bail!(
 316            "snowflake sql api partition request http {}: {}",
 317            status.as_u16(),
 318            body_text
 319        );
 320    }
 321
 322    if body_bytes.is_empty() {
 323        anyhow::bail!(
 324            "snowflake sql api partition {} returned empty response body (http {})",
 325            partition,
 326            status.as_u16()
 327        );
 328    }
 329
 330    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 331        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 332        format!(
 333            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 334            partition,
 335            status.as_u16(),
 336            body_preview
 337        )
 338    })
 339}
 340
 341pub(crate) async fn run_sql(
 342    http_client: Arc<dyn HttpClient>,
 343    base_url: &str,
 344    token: &str,
 345    request: &serde_json::Value,
 346) -> Result<SnowflakeStatementResponse> {
 347    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 348
 349    let request_body =
 350        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 351
 352    let http_request = Request::builder()
 353        .method(Method::POST)
 354        .uri(url.as_str())
 355        .header("Authorization", format!("Bearer {token}"))
 356        .header(
 357            "X-Snowflake-Authorization-Token-Type",
 358            "PROGRAMMATIC_ACCESS_TOKEN",
 359        )
 360        .header("Content-Type", "application/json")
 361        .header("Accept", "application/json")
 362        .header("User-Agent", "edit_prediction_cli")
 363        .body(AsyncBody::from(request_body.clone()))?;
 364
 365    let response = http_client
 366        .send(http_request)
 367        .await
 368        .context("failed to send request to Snowflake SQL API")?;
 369
 370    let status = response.status();
 371    let body_bytes = {
 372        use futures::AsyncReadExt as _;
 373
 374        let mut body = response.into_body();
 375        let mut bytes = Vec::new();
 376        body.read_to_end(&mut bytes)
 377            .await
 378            .context("failed to read Snowflake SQL API response body")?;
 379        bytes
 380    };
 381
 382    if !status.is_success() && status.as_u16() != 202 {
 383        let body_text = String::from_utf8_lossy(&body_bytes);
 384        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 385    }
 386
 387    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 388        .context("failed to parse Snowflake SQL API response JSON")
 389}
 390
 391pub async fn fetch_rejected_examples_after(
 392    http_client: Arc<dyn HttpClient>,
 393    after_timestamps: &[String],
 394    max_rows_per_timestamp: usize,
 395    offset: usize,
 396    background_executor: BackgroundExecutor,
 397    min_capture_version: Option<MinCaptureVersion>,
 398) -> Result<Vec<Example>> {
 399    if after_timestamps.is_empty() {
 400        return Ok(Vec::new());
 401    }
 402
 403    let progress = Progress::global();
 404
 405    let mut all_examples = Vec::new();
 406
 407    for after_date in after_timestamps.iter() {
 408        let step_progress_name = format!("rejected>{after_date}");
 409        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 410        step_progress.set_substatus("querying");
 411
 412        let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
 413        let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
 414        let min_minor_str_ref = min_minor_str.as_deref();
 415        let min_patch_str_ref = min_patch_str.as_deref();
 416
 417        let statement = indoc! {r#"
 418            SELECT
 419                req.event_properties:request_id::string AS request_id,
 420                req.device_id::string AS device_id,
 421                req.time::string AS time,
 422                req.event_properties:input AS input,
 423                req.event_properties:prompt::string AS prompt,
 424                req.event_properties:output::string AS output,
 425                rej.event_properties:was_shown::boolean AS was_shown,
 426                rej.event_properties:reason::string AS reason,
 427                req.event_properties:zed_version::string AS zed_version
 428            FROM events req
 429            INNER JOIN events rej
 430                ON req.event_properties:request_id = rej.event_properties:request_id
 431            WHERE req.event_type = ?
 432                AND rej.event_type = ?
 433                AND req.event_properties:version = 'V3'
 434                AND rej.event_properties:was_shown = true
 435                AND req.event_properties:input:can_collect_data = true
 436                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 437                AND (? IS NULL OR (
 438                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 439                    OR (
 440                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 441                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 442                    )
 443                ))
 444            ORDER BY req.time ASC
 445            LIMIT ?
 446            OFFSET ?
 447        "#};
 448
 449        let bindings = json!({
 450            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 451            "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
 452            "3": { "type": "TEXT", "value": after_date },
 453            "4": { "type": "FIXED", "value": min_minor_str_ref },
 454            "5": { "type": "FIXED", "value": min_minor_str_ref },
 455            "6": { "type": "FIXED", "value": min_minor_str_ref },
 456            "7": { "type": "FIXED", "value": min_patch_str_ref },
 457            "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 458            "9": { "type": "FIXED", "value": offset.to_string() }
 459        });
 460
 461        let examples = fetch_examples_with_query(
 462            http_client.clone(),
 463            &step_progress,
 464            background_executor.clone(),
 465            statement,
 466            bindings,
 467            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 468            &[
 469                "request_id",
 470                "device_id",
 471                "time",
 472                "input",
 473                "prompt",
 474                "output",
 475                "was_shown",
 476                "reason",
 477                "zed_version",
 478            ],
 479            rejected_examples_from_response,
 480        )
 481        .await?;
 482
 483        all_examples.extend(examples);
 484    }
 485
 486    Ok(all_examples)
 487}
 488
 489pub async fn fetch_requested_examples_after(
 490    http_client: Arc<dyn HttpClient>,
 491    after_timestamps: &[String],
 492    max_rows_per_timestamp: usize,
 493    offset: usize,
 494    background_executor: BackgroundExecutor,
 495    min_capture_version: Option<MinCaptureVersion>,
 496) -> Result<Vec<Example>> {
 497    if after_timestamps.is_empty() {
 498        return Ok(Vec::new());
 499    }
 500
 501    let progress = Progress::global();
 502
 503    let mut all_examples = Vec::new();
 504
 505    for after_date in after_timestamps.iter() {
 506        let step_progress_name = format!("requested>{after_date}");
 507        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 508        step_progress.set_substatus("querying");
 509
 510        let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
 511        let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
 512        let min_minor_str_ref = min_minor_str.as_deref();
 513        let min_patch_str_ref = min_patch_str.as_deref();
 514
 515        let statement = indoc! {r#"
 516            SELECT
 517                req.event_properties:request_id::string AS request_id,
 518                req.device_id::string AS device_id,
 519                req.time::string AS time,
 520                req.event_properties:input AS input,
 521                req.event_properties:zed_version::string AS zed_version
 522            FROM events req
 523            WHERE req.event_type = ?
 524                AND req.event_properties:version = 'V3'
 525                AND req.event_properties:input:can_collect_data = true
 526                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 527                AND (? IS NULL OR (
 528                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 529                    OR (
 530                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 531                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 532                    )
 533                ))
 534            ORDER BY req.time ASC
 535            LIMIT ?
 536            OFFSET ?
 537        "#};
 538
 539        let bindings = json!({
 540            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 541            "2": { "type": "TEXT", "value": after_date },
 542            "3": { "type": "FIXED", "value": min_minor_str_ref },
 543            "4": { "type": "FIXED", "value": min_minor_str_ref },
 544            "5": { "type": "FIXED", "value": min_minor_str_ref },
 545            "6": { "type": "FIXED", "value": min_patch_str_ref },
 546            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 547            "8": { "type": "FIXED", "value": offset.to_string() }
 548        });
 549
 550        let examples = fetch_examples_with_query(
 551            http_client.clone(),
 552            &step_progress,
 553            background_executor.clone(),
 554            statement,
 555            bindings,
 556            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 557            &["request_id", "device_id", "time", "input", "zed_version"],
 558            requested_examples_from_response,
 559        )
 560        .await?;
 561
 562        all_examples.extend(examples);
 563    }
 564
 565    Ok(all_examples)
 566}
 567
 568pub async fn fetch_captured_examples_after(
 569    http_client: Arc<dyn HttpClient>,
 570    after_timestamps: &[String],
 571    max_rows_per_timestamp: usize,
 572    offset: usize,
 573    background_executor: BackgroundExecutor,
 574    min_capture_version: Option<MinCaptureVersion>,
 575) -> Result<Vec<Example>> {
 576    if after_timestamps.is_empty() {
 577        return Ok(Vec::new());
 578    }
 579
 580    let progress = Progress::global();
 581
 582    let mut all_examples = Vec::new();
 583
 584    for after_date in after_timestamps.iter() {
 585        let step_progress_name = format!("captured>{after_date}");
 586        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 587        step_progress.set_substatus("querying");
 588
 589        let min_minor_str = min_capture_version.map(|version| version.minor.to_string());
 590        let min_patch_str = min_capture_version.map(|version| version.patch.to_string());
 591        let min_minor_str_ref = min_minor_str.as_deref();
 592        let min_patch_str_ref = min_patch_str.as_deref();
 593
 594        let statement = indoc! {r#"
 595            SELECT
 596                settled.event_properties:request_id::string AS request_id,
 597                settled.device_id::string AS device_id,
 598                settled.time::string AS time,
 599                req.event_properties:input AS input,
 600                settled.event_properties:settled_editable_region::string AS settled_editable_region,
 601                settled.event_properties:example AS example,
 602                req.event_properties:zed_version::string AS zed_version
 603            FROM events settled
 604            INNER JOIN events req
 605                ON settled.event_properties:request_id::string = req.event_properties:request_id::string
 606            WHERE settled.event_type = ?
 607                AND req.event_type = ?
 608                AND req.event_properties:version = 'V3'
 609                AND req.event_properties:input:can_collect_data = true
 610                AND settled.event_properties:example IS NOT NULL
 611                AND TYPEOF(settled.event_properties:example) != 'NULL_VALUE'
 612                AND settled.time > TRY_TO_TIMESTAMP_NTZ(?)
 613                AND (? IS NULL OR (
 614                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 615                    OR (
 616                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 617                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 618                    )
 619                ))
 620            ORDER BY settled.time ASC
 621            LIMIT ?
 622            OFFSET ?
 623        "#};
 624
 625        let bindings = json!({
 626            "1": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
 627            "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 628            "3": { "type": "TEXT", "value": after_date },
 629            "4": { "type": "FIXED", "value": min_minor_str_ref },
 630            "5": { "type": "FIXED", "value": min_minor_str_ref },
 631            "6": { "type": "FIXED", "value": min_minor_str_ref },
 632            "7": { "type": "FIXED", "value": min_patch_str_ref },
 633            "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 634            "9": { "type": "FIXED", "value": offset.to_string() }
 635        });
 636
 637        let examples = fetch_examples_with_query(
 638            http_client.clone(),
 639            &step_progress,
 640            background_executor.clone(),
 641            statement,
 642            bindings,
 643            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 644            &[
 645                "request_id",
 646                "device_id",
 647                "time",
 648                "input",
 649                "settled_editable_region",
 650                "example",
 651                "zed_version",
 652            ],
 653            captured_examples_from_response,
 654        )
 655        .await?;
 656
 657        all_examples.extend(examples);
 658    }
 659
 660    Ok(all_examples)
 661}
 662
 663pub async fn fetch_settled_examples_after(
 664    http_client: Arc<dyn HttpClient>,
 665    after_timestamps: &[String],
 666    max_rows_per_timestamp: usize,
 667    offset: usize,
 668    background_executor: BackgroundExecutor,
 669    min_capture_version: Option<MinCaptureVersion>,
 670) -> Result<Vec<Example>> {
 671    if after_timestamps.is_empty() {
 672        return Ok(Vec::new());
 673    }
 674
 675    let progress = Progress::global();
 676
 677    let mut all_examples = Vec::new();
 678
 679    for after_date in after_timestamps.iter() {
 680        let step_progress_name = format!("settled>{after_date}");
 681        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 682        step_progress.set_substatus("querying");
 683
 684        let _ = min_capture_version;
 685
 686        let statement = indoc! {r#"
 687            WITH requested AS (
 688                SELECT
 689                    req.event_properties:request_id::string AS request_id,
 690                    req.device_id::string AS device_id,
 691                    req.time AS req_time,
 692                    req.time::string AS time,
 693                    req.event_properties:input AS input,
 694                    req.event_properties:format::string AS requested_format,
 695                    req.event_properties:output::string AS requested_output,
 696                    req.event_properties:zed_version::string AS zed_version
 697                FROM events req
 698                WHERE req.event_type = ?
 699                    AND req.event_properties:version = 'V3'
 700                    AND req.event_properties:input:can_collect_data = true
 701                    AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 702            )
 703            SELECT
 704                req.request_id AS request_id,
 705                req.device_id AS device_id,
 706                req.time AS time,
 707                req.input AS input,
 708                req.requested_output AS requested_output,
 709                settled.event_properties:settled_editable_region::string AS settled_editable_region,
 710                req.requested_format AS requested_format,
 711                req.zed_version AS zed_version
 712            FROM requested req
 713            INNER JOIN events settled
 714                ON req.request_id = settled.event_properties:request_id::string
 715            WHERE settled.event_type = ?
 716            ORDER BY req.req_time ASC
 717            LIMIT ?
 718            OFFSET ?
 719        "#};
 720
 721        let bindings = json!({
 722            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 723            "2": { "type": "TEXT", "value": after_date },
 724            "3": { "type": "TEXT", "value": EDIT_PREDICTION_SETTLED_EVENT },
 725            "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 726            "5": { "type": "FIXED", "value": offset.to_string() }
 727        });
 728
 729        let examples = fetch_examples_with_query(
 730            http_client.clone(),
 731            &step_progress,
 732            background_executor.clone(),
 733            statement,
 734            bindings,
 735            SETTLED_STATEMENT_TIMEOUT_SECONDS,
 736            &[
 737                "request_id",
 738                "device_id",
 739                "time",
 740                "input",
 741                "requested_output",
 742                "settled_editable_region",
 743                "requested_format",
 744                "zed_version",
 745            ],
 746            settled_examples_from_response,
 747        )
 748        .await?;
 749
 750        all_examples.extend(examples);
 751    }
 752
 753    Ok(all_examples)
 754}
 755
 756pub async fn fetch_rated_examples_after(
 757    http_client: Arc<dyn HttpClient>,
 758    inputs: &[(String, Option<EditPredictionRating>)],
 759    max_rows_per_timestamp: usize,
 760    offset: usize,
 761    background_executor: BackgroundExecutor,
 762    _min_capture_version: Option<MinCaptureVersion>,
 763) -> Result<Vec<Example>> {
 764    if inputs.is_empty() {
 765        return Ok(Vec::new());
 766    }
 767
 768    let progress = Progress::global();
 769
 770    let mut all_examples = Vec::new();
 771
 772    for (after_date, rating_filter) in inputs.iter() {
 773        let filter_label = match rating_filter {
 774            None => "",
 775            Some(EditPredictionRating::Positive) => ":positive",
 776            Some(EditPredictionRating::Negative) => ":negative",
 777        };
 778        let step_progress_name = format!("rated{filter_label}>{after_date}");
 779        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 780        step_progress.set_substatus("querying");
 781
 782        let rating_value = rating_filter.as_ref().map(|rating| match rating {
 783            EditPredictionRating::Positive => "Positive",
 784            EditPredictionRating::Negative => "Negative",
 785        });
 786
 787        let statement = indoc! {r#"
 788            SELECT
 789                rated.event_properties:request_id::string AS request_id,
 790                rated.event_properties:inputs AS inputs,
 791                rated.event_properties:output::string AS output,
 792                rated.event_properties:rating::string AS rating,
 793                rated.event_properties:feedback::string AS feedback,
 794                rated.device_id::string AS device_id,
 795                rated.time::string AS time,
 796                deploy.event_properties:experiment_name::string AS experiment_name,
 797                deploy.event_properties:environment::string AS environment,
 798                rated.event_properties:zed_version::string AS zed_version
 799            FROM events rated
 800            LEFT JOIN events req
 801                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
 802                AND req.event_type = ?
 803            LEFT JOIN events deploy
 804                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
 805                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
 806                AND deploy.event_type = ?
 807            WHERE rated.event_type = ?
 808                AND (? IS NULL OR rated.event_properties:rating::string = ?)
 809                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
 810                AND rated.event_properties:inputs IS NOT NULL
 811                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
 812                AND rated.event_properties:output IS NOT NULL
 813                AND rated.event_properties:inputs:can_collect_data = true
 814            ORDER BY rated.time ASC
 815            LIMIT ?
 816            OFFSET ?
 817        "#};
 818
 819        let bindings = json!({
 820            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 821            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
 822            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
 823            "4": { "type": "TEXT", "value": rating_value },
 824            "5": { "type": "TEXT", "value": rating_value },
 825            "6": { "type": "TEXT", "value": after_date },
 826            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 827            "8": { "type": "FIXED", "value": offset.to_string() }
 828        });
 829
 830        let examples = fetch_examples_with_query(
 831            http_client.clone(),
 832            &step_progress,
 833            background_executor.clone(),
 834            statement,
 835            bindings,
 836            DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 837            &[
 838                "request_id",
 839                "inputs",
 840                "output",
 841                "rating",
 842                "feedback",
 843                "device_id",
 844                "time",
 845                "experiment_name",
 846                "environment",
 847                "zed_version",
 848            ],
 849            rated_examples_from_response,
 850        )
 851        .await?;
 852
 853        all_examples.extend(examples);
 854    }
 855
 856    Ok(all_examples)
 857}
 858
 859fn rated_examples_from_response<'a>(
 860    response: &'a SnowflakeStatementResponse,
 861    column_indices: &'a std::collections::HashMap<String, usize>,
 862) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
 863    if let Some(code) = &response.code {
 864        if code != SNOWFLAKE_SUCCESS_CODE {
 865            anyhow::bail!(
 866                "snowflake sql api returned error code={code} message={}",
 867                response.message.as_deref().unwrap_or("<no message>")
 868            );
 869        }
 870    }
 871
 872    let iter = response
 873        .data
 874        .iter()
 875        .enumerate()
 876        .filter_map(move |(row_index, data_row)| {
 877            let get_string = |name: &str| -> Option<String> {
 878                let index = column_indices.get(name).copied()?;
 879                match data_row.get(index)? {
 880                    JsonValue::String(s) => Some(s.clone()),
 881                    JsonValue::Null => None,
 882                    other => Some(other.to_string()),
 883                }
 884            };
 885
 886            let get_json = |name: &str| -> Option<JsonValue> {
 887                let index = column_indices.get(name).copied()?;
 888                let value = data_row.get(index)?;
 889                if value.is_null() {
 890                    return None;
 891                }
 892                match value {
 893                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 894                    other => Some(other.clone()),
 895                }
 896            };
 897
 898            let request_id = get_string("request_id");
 899            let inputs_json = get_json("inputs");
 900            let inputs: Option<ZetaPromptInput> = match &inputs_json {
 901                Some(v) => match serde_json::from_value(v.clone()) {
 902                    Ok(parsed) => Some(parsed),
 903                    Err(e) => {
 904                        log::warn!(
 905                            "skipping row {row_index}: failed to parse inputs - {e}",
 906                        );
 907                        return None;
 908                    }
 909                },
 910                None => None,
 911            };
 912            let output = get_string("output");
 913            let rating = get_string("rating");
 914            let feedback = get_string("feedback").unwrap_or_default();
 915            let device_id = get_string("device_id");
 916            let time = get_string("time");
 917            let experiment_name = get_string("experiment_name");
 918            let environment = get_string("environment");
 919            let zed_version = get_string("zed_version");
 920
 921            match (inputs, output.clone(), rating.clone(), time.clone()) {
 922                (Some(inputs), Some(output), Some(rating), Some(time)) => {
 923                    Some(build_rated_example(
 924                        request_id,
 925                        device_id.unwrap_or_default(),
 926                        time,
 927                        inputs,
 928                        output,
 929                        rating,
 930                        feedback,
 931                        experiment_name,
 932                        environment,
 933                        zed_version,
 934                    ))
 935                }
 936                _ => {
 937                    log::warn!(
 938                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} time={:?}",
 939                        inputs_json.is_some(),
 940                        output.is_some(),
 941                        rating.is_some(),
 942                        time.is_some(),
 943                    );
 944                    None
 945                }
 946            }
 947        });
 948
 949    Ok(Box::new(iter))
 950}
 951
 952fn build_rated_example(
 953    request_id: Option<String>,
 954    device_id: String,
 955    time: String,
 956    input: ZetaPromptInput,
 957    output: String,
 958    rating: String,
 959    feedback: String,
 960    experiment_name: Option<String>,
 961    environment: Option<String>,
 962    zed_version: Option<String>,
 963) -> Example {
 964    let parsed_rating = if rating == "Positive" {
 965        EditPredictionRating::Positive
 966    } else {
 967        EditPredictionRating::Negative
 968    };
 969    let is_positive = parsed_rating == EditPredictionRating::Positive;
 970    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
 971
 972    let mut tags = Vec::with_capacity(3);
 973    tags.push(if is_positive {
 974        "rated:positive".to_string()
 975    } else {
 976        "rated:negative".to_string()
 977    });
 978    if let Some(experiment) = experiment_name {
 979        tags.push(format!("experiment:{experiment}"));
 980    }
 981    if let Some(env) = environment {
 982        tags.push(format!("environment:{env}"));
 983    }
 984
 985    let mut example =
 986        build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
 987
 988    example.spec.rating = Some(parsed_rating);
 989
 990    if !feedback.is_empty() {
 991        example
 992            .spec
 993            .human_feedback
 994            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
 995    }
 996
 997    if is_positive {
 998        example.spec.expected_patches = vec![output];
 999    } else {
1000        example.spec.rejected_patch = Some(output);
1001    }
1002
1003    example
1004}
1005
1006fn requested_examples_from_response<'a>(
1007    response: &'a SnowflakeStatementResponse,
1008    column_indices: &'a std::collections::HashMap<String, usize>,
1009) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1010    if let Some(code) = &response.code {
1011        if code != SNOWFLAKE_SUCCESS_CODE {
1012            anyhow::bail!(
1013                "snowflake sql api returned error code={code} message={}",
1014                response.message.as_deref().unwrap_or("<no message>")
1015            );
1016        }
1017    }
1018
1019    let iter = response
1020        .data
1021        .iter()
1022        .enumerate()
1023        .filter_map(move |(row_index, data_row)| {
1024            let get_string = |name: &str| -> Option<String> {
1025                let index = column_indices.get(name).copied()?;
1026                match data_row.get(index)? {
1027                    JsonValue::String(s) => Some(s.clone()),
1028                    JsonValue::Null => None,
1029                    other => Some(other.to_string()),
1030                }
1031            };
1032
1033            let get_json = |name: &str| -> Option<JsonValue> {
1034                let index = column_indices.get(name).copied()?;
1035                let value = data_row.get(index)?;
1036                if value.is_null() {
1037                    return None;
1038                }
1039                match value {
1040                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1041                    other => Some(other.clone()),
1042                }
1043            };
1044
1045            let request_id_str = get_string("request_id");
1046            let device_id = get_string("device_id");
1047            let time = get_string("time");
1048            let input_json = get_json("input");
1049            let input: Option<ZetaPromptInput> =
1050                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1051            let zed_version = get_string("zed_version");
1052
1053            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1054                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1055                    Some(build_example_from_snowflake(
1056                        request_id,
1057                        device_id,
1058                        time,
1059                        input,
1060                        vec!["requested".to_string()],
1061                        None,
1062                        zed_version,
1063                    ))
1064                }
1065                _ => {
1066                    log::warn!(
1067                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1068                        request_id_str.is_some(),
1069                        device_id.is_some(),
1070                        time.is_some(),
1071                        input_json.is_some(),
1072                    );
1073                    None
1074                }
1075            }
1076        });
1077
1078    Ok(Box::new(iter))
1079}
1080
1081fn settled_examples_from_response<'a>(
1082    response: &'a SnowflakeStatementResponse,
1083    column_indices: &'a std::collections::HashMap<String, usize>,
1084) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1085    if let Some(code) = &response.code {
1086        if code != SNOWFLAKE_SUCCESS_CODE {
1087            anyhow::bail!(
1088                "snowflake sql api returned error code={code} message={}",
1089                response.message.as_deref().unwrap_or("<no message>")
1090            );
1091        }
1092    }
1093
1094    let iter = response
1095        .data
1096        .iter()
1097        .enumerate()
1098        .filter_map(move |(row_index, data_row)| {
1099            let get_value = |name: &str| -> Option<JsonValue> {
1100                let index = column_indices.get(name).copied()?;
1101                let value = data_row.get(index)?;
1102                if value.is_null() {
1103                    None
1104                } else {
1105                    Some(value.clone())
1106                }
1107            };
1108
1109            let get_string = |name: &str| -> Option<String> {
1110                match get_value(name)? {
1111                    JsonValue::String(s) => Some(s),
1112                    other => Some(other.to_string()),
1113                }
1114            };
1115
1116            let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1117                let value = raw?;
1118                match value {
1119                    JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1120                    other => Some(other.clone()),
1121                }
1122            };
1123
1124            let request_id_str = get_string("request_id");
1125            let device_id = get_string("device_id");
1126            let time = get_string("time");
1127            let input_raw = get_value("input");
1128            let input_json = parse_json_value(input_raw.as_ref());
1129            let input: Option<ZetaPromptInput> = input_json
1130                .as_ref()
1131                .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1132            let requested_output = get_string("requested_output");
1133            let settled_editable_region = get_string("settled_editable_region");
1134            let requested_format =
1135                get_string("requested_format").and_then(|s| ZetaFormat::parse(&s).ok());
1136            let zed_version = get_string("zed_version");
1137
1138            match (
1139                request_id_str.clone(),
1140                device_id.clone(),
1141                time.clone(),
1142                input.clone(),
1143                requested_output.clone(),
1144                settled_editable_region.clone(),
1145                requested_format,
1146            ) {
1147                (
1148                    Some(request_id),
1149                    Some(device_id),
1150                    Some(time),
1151                    Some(input),
1152                    Some(requested_output),
1153                    Some(settled_editable_region),
1154                    Some(requested_format),
1155                ) => Some(build_settled_example(
1156                    request_id,
1157                    device_id,
1158                    time,
1159                    input,
1160                    requested_output,
1161                    settled_editable_region,
1162                    requested_format,
1163                    zed_version,
1164                )),
1165                _ => {
1166                    let mut missing_fields = Vec::new();
1167
1168                    if request_id_str.is_none() {
1169                        missing_fields.push("request_id");
1170                    }
1171                    if device_id.is_none() {
1172                        missing_fields.push("device_id");
1173                    }
1174                    if time.is_none() {
1175                        missing_fields.push("time");
1176                    }
1177                    if input_raw.is_none() || input_json.is_none() || input.is_none() {
1178                        missing_fields.push("input");
1179                    }
1180                    if requested_output.is_none() {
1181                        missing_fields.push("requested_output");
1182                    }
1183                    if settled_editable_region.is_none() {
1184                        missing_fields.push("settled_editable_region");
1185                    }
1186                    if requested_format.is_none() {
1187                        missing_fields.push("requested_format");
1188                    }
1189
1190                    log::warn!(
1191                        "skipping settled row {row_index}: [{}]",
1192                        missing_fields.join(", "),
1193                    );
1194                    None
1195                }
1196            }
1197        });
1198
1199    Ok(Box::new(iter))
1200}
1201
1202fn captured_examples_from_response<'a>(
1203    response: &'a SnowflakeStatementResponse,
1204    column_indices: &'a std::collections::HashMap<String, usize>,
1205) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1206    if let Some(code) = &response.code {
1207        if code != SNOWFLAKE_SUCCESS_CODE {
1208            anyhow::bail!(
1209                "snowflake sql api returned error code={code} message={}",
1210                response.message.as_deref().unwrap_or("<no message>")
1211            );
1212        }
1213    }
1214
1215    let iter = response
1216        .data
1217        .iter()
1218        .enumerate()
1219        .filter_map(move |(row_index, data_row)| {
1220            let get_value = |name: &str| -> Option<JsonValue> {
1221                let index = column_indices.get(name).copied()?;
1222                let value = data_row.get(index)?;
1223                if value.is_null() {
1224                    None
1225                } else {
1226                    Some(value.clone())
1227                }
1228            };
1229
1230            let get_string = |name: &str| -> Option<String> {
1231                match get_value(name)? {
1232                    JsonValue::String(s) => Some(s),
1233                    other => Some(other.to_string()),
1234                }
1235            };
1236
1237            let parse_json_value = |raw: Option<&JsonValue>| -> Option<JsonValue> {
1238                let value = raw?;
1239                match value {
1240                    JsonValue::String(s) => serde_json::from_str::<JsonValue>(s).ok(),
1241                    other => Some(other.clone()),
1242                }
1243            };
1244
1245            let request_id = get_string("request_id");
1246            let device_id = get_string("device_id");
1247            let time = get_string("time");
1248            let input_raw = get_value("input");
1249            let input_json = parse_json_value(input_raw.as_ref());
1250            let input: Option<ZetaPromptInput> = input_json
1251                .as_ref()
1252                .and_then(|parsed| serde_json::from_value(parsed.clone()).ok());
1253            let example_raw = get_value("example");
1254            let example_json = parse_json_value(example_raw.as_ref());
1255            let example_spec: Option<ExampleSpec> = example_json.as_ref().and_then(|parsed| {
1256                serde_json::from_value(parsed.clone())
1257                    .or_else(|_| {
1258                        parsed
1259                            .as_str()
1260                            .and_then(|markdown| ExampleSpec::from_markdown(markdown).ok())
1261                            .ok_or_else(|| {
1262                                serde_json::Error::io(std::io::Error::other("not markdown"))
1263                            })
1264                    })
1265                    .ok()
1266            });
1267            let has_example_spec = example_spec.is_some();
1268            let settled_editable_region = get_string("settled_editable_region");
1269            let zed_version = get_string("zed_version");
1270
1271            match (
1272                request_id.clone(),
1273                device_id.clone(),
1274                time.clone(),
1275                input.clone(),
1276                example_spec,
1277                settled_editable_region.clone(),
1278            ) {
1279                (
1280                    Some(request_id),
1281                    Some(device_id),
1282                    Some(time),
1283                    Some(input),
1284                    Some(example_spec),
1285                    Some(settled_editable_region),
1286                ) => Some(build_captured_example(
1287                    request_id,
1288                    device_id,
1289                    time,
1290                    input,
1291                    example_spec,
1292                    settled_editable_region,
1293                    zed_version,
1294                )),
1295                _ => {
1296                    let mut missing_fields = Vec::new();
1297
1298                    if request_id.is_none() {
1299                        missing_fields.push("request_id");
1300                    }
1301                    if device_id.is_none() {
1302                        missing_fields.push("device_id");
1303                    }
1304                    if time.is_none() {
1305                        missing_fields.push("time");
1306                    }
1307                    if input_raw.is_none() || input_json.is_none() || input.is_none() {
1308                        missing_fields.push("input");
1309                    }
1310                    if example_raw.is_none() || !has_example_spec {
1311                        missing_fields.push("example");
1312                    }
1313                    if settled_editable_region.is_none() {
1314                        missing_fields.push("settled_editable_region");
1315                    }
1316
1317                    log::warn!(
1318                        "skipping captured row {row_index}: [{}]",
1319                        missing_fields.join(", "),
1320                    );
1321                    None
1322                }
1323            }
1324        });
1325
1326    Ok(Box::new(iter))
1327}
1328
1329fn build_settled_example(
1330    request_id: String,
1331    device_id: String,
1332    time: String,
1333    input: ZetaPromptInput,
1334    requested_output: String,
1335    settled_editable_region: String,
1336    requested_format: ZetaFormat,
1337    zed_version: Option<String>,
1338) -> Example {
1339    let requested_editable_range =
1340        excerpt_range_for_format(requested_format, &input.excerpt_ranges).0;
1341
1342    let base_cursor_excerpt = input.cursor_excerpt.to_string();
1343
1344    let requested_range_is_valid = requested_editable_range.start <= requested_editable_range.end
1345        && requested_editable_range.end <= base_cursor_excerpt.len();
1346    let mut example = build_example_from_snowflake(
1347        request_id.clone(),
1348        device_id,
1349        time,
1350        input,
1351        vec!["settled".to_string()],
1352        None,
1353        zed_version,
1354    );
1355
1356    if !requested_range_is_valid {
1357        log::warn!(
1358            "skipping malformed requested range for request {}: requested={:?} (base_len={})",
1359            request_id,
1360            requested_editable_range,
1361            base_cursor_excerpt.len(),
1362        );
1363        return example;
1364    }
1365
1366    let settled_replacement = settled_editable_region.as_str();
1367    let rejected_patch = build_output_patch(
1368        &example.spec.cursor_path,
1369        &base_cursor_excerpt,
1370        &requested_editable_range,
1371        &requested_output,
1372    );
1373    let expected_patch = build_output_patch(
1374        &example.spec.cursor_path,
1375        &base_cursor_excerpt,
1376        &requested_editable_range,
1377        settled_replacement,
1378    );
1379
1380    example.spec.expected_patches = vec![expected_patch];
1381    example.spec.rejected_patch = Some(rejected_patch);
1382    example
1383}
1384
1385fn build_captured_example(
1386    request_id: String,
1387    device_id: String,
1388    time: String,
1389    input: ZetaPromptInput,
1390    mut example_spec: ExampleSpec,
1391    settled_editable_region: String,
1392    zed_version: Option<String>,
1393) -> Example {
1394    let expected_patch = build_output_patch(
1395        &input.cursor_path,
1396        input.cursor_excerpt.as_ref(),
1397        &input.excerpt_ranges.editable_350,
1398        settled_editable_region.as_str(),
1399    );
1400
1401    example_spec.expected_patches = vec![expected_patch];
1402    example_spec.telemetry = Some(TelemetrySource {
1403        request_id,
1404        device_id,
1405        time,
1406        rejection_reason: String::new(),
1407        was_shown: false,
1408    });
1409
1410    Example {
1411        spec: example_spec,
1412        zed_version,
1413        prompt_inputs: Some(input),
1414        prompt: None,
1415        predictions: Vec::new(),
1416        score: Vec::new(),
1417        qa: Vec::new(),
1418        state: None,
1419    }
1420}
1421
1422fn rejected_examples_from_response<'a>(
1423    response: &'a SnowflakeStatementResponse,
1424    column_indices: &'a std::collections::HashMap<String, usize>,
1425) -> Result<Box<dyn Iterator<Item = Example> + 'a>> {
1426    if let Some(code) = &response.code {
1427        if code != SNOWFLAKE_SUCCESS_CODE {
1428            anyhow::bail!(
1429                "snowflake sql api returned error code={code} message={}",
1430                response.message.as_deref().unwrap_or("<no message>")
1431            );
1432        }
1433    }
1434
1435    let iter = response
1436        .data
1437        .iter()
1438        .enumerate()
1439        .filter_map(move |(row_index, data_row)| {
1440            let get_string = |name: &str| -> Option<String> {
1441                let index = column_indices.get(name).copied()?;
1442                match data_row.get(index)? {
1443                    JsonValue::String(s) => Some(s.clone()),
1444                    JsonValue::Null => None,
1445                    other => Some(other.to_string()),
1446                }
1447            };
1448
1449            let get_json = |name: &str| -> Option<JsonValue> {
1450                let index = column_indices.get(name).copied()?;
1451                let value = data_row.get(index)?;
1452                if value.is_null() {
1453                    return None;
1454                }
1455                match value {
1456                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1457                    other => Some(other.clone()),
1458                }
1459            };
1460
1461            let get_bool = |name: &str| -> Option<bool> {
1462                let index = column_indices.get(name).copied()?;
1463                match data_row.get(index)? {
1464                    JsonValue::Bool(b) => Some(*b),
1465                    JsonValue::String(s) => s.parse().ok(),
1466                    _ => None,
1467                }
1468            };
1469
1470            let request_id_str = get_string("request_id");
1471            let device_id = get_string("device_id");
1472            let time = get_string("time");
1473            let input_json = get_json("input");
1474            let input: Option<ZetaPromptInput> =
1475                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1476            let output = get_string("output");
1477            let was_shown = get_bool("was_shown");
1478            let reason = get_string("reason");
1479            let zed_version = get_string("zed_version");
1480
1481            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1482                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1483                    Some(build_rejected_example(
1484                        request_id,
1485                        device_id,
1486                        time,
1487                        input,
1488                        output,
1489                        was_shown,
1490                        reason,
1491                        zed_version,
1492                    ))
1493                }
1494                _ => {
1495                    log::warn!(
1496                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1497                        request_id_str.is_some(),
1498                        device_id.is_some(),
1499                        time.is_some(),
1500                        input_json.is_some(),
1501                        output.is_some(),
1502                        was_shown.is_some(),
1503                        reason.is_some()
1504                    );
1505                    None
1506                }
1507            }
1508        });
1509
1510    Ok(Box::new(iter))
1511}
1512
1513fn build_rejected_example(
1514    request_id: String,
1515    device_id: String,
1516    time: String,
1517    input: ZetaPromptInput,
1518    output: String,
1519    was_shown: bool,
1520    reason: String,
1521    zed_version: Option<String>,
1522) -> Example {
1523    let rejected_patch = build_output_patch(
1524        &input.cursor_path,
1525        input.cursor_excerpt.as_ref(),
1526        &input.excerpt_ranges.editable_350,
1527        &output,
1528    );
1529    let mut example = build_example_from_snowflake(
1530        request_id,
1531        device_id,
1532        time,
1533        input,
1534        vec![format!("rejection:{}", reason.to_lowercase())],
1535        Some(RejectionInfo { reason, was_shown }),
1536        zed_version,
1537    );
1538    example.spec.rejected_patch = Some(rejected_patch);
1539    example
1540}
1541
/// Rejection details from a `Predictive Edit Rejected` row, copied into the example's
/// telemetry.
struct RejectionInfo {
    /// The `was_shown` flag reported with the rejection.
    was_shown: bool,
    /// The rejection reason string as reported by telemetry.
    reason: String,
}
1546
1547fn build_example_from_snowflake(
1548    request_id: String,
1549    device_id: String,
1550    time: String,
1551    input: ZetaPromptInput,
1552    tags: Vec<String>,
1553    rejection: Option<RejectionInfo>,
1554    zed_version: Option<String>,
1555) -> Example {
1556    let cursor_excerpt = input.cursor_excerpt.as_ref();
1557    let cursor_offset = input.cursor_offset_in_excerpt;
1558
1559    let mut edit_history = String::new();
1560    for event in &input.events {
1561        zeta_prompt::write_event(&mut edit_history, event);
1562        edit_history.push('\n');
1563    }
1564
1565    let (rejection_reason, was_shown) = match &rejection {
1566        Some(r) => (r.reason.clone(), r.was_shown),
1567        None => (String::new(), false),
1568    };
1569
1570    let spec = ExampleSpec {
1571        name: request_id.clone(),
1572        repository_url: String::new(),
1573        revision: String::new(),
1574        tags,
1575        reasoning: None,
1576        uncommitted_diff: String::new(),
1577        cursor_path: input.cursor_path.clone(),
1578        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1579        edit_history,
1580        expected_patches: Vec::new(),
1581        rejected_patch: None,
1582        telemetry: Some(TelemetrySource {
1583            request_id,
1584            device_id,
1585            time,
1586            rejection_reason,
1587            was_shown,
1588        }),
1589        human_feedback: Vec::new(),
1590        rating: None,
1591    };
1592
1593    Example {
1594        spec,
1595        zed_version,
1596        prompt_inputs: Some(input),
1597        prompt: None,
1598        predictions: Vec::new(),
1599        score: Vec::new(),
1600        qa: Vec::new(),
1601        state: None,
1602    }
1603}
1604
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte offset
/// `cursor_offset`.
///
/// Offsets past the end are clamped to the end of the excerpt. An offset that falls inside a
/// multi-byte character is moved back to the start of that character instead of panicking.
/// The offset may come from telemetry and is not guaranteed to line up with `excerpt`.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_index = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this walk always terminates.
    while !excerpt.is_char_boundary(split_index) {
        split_index -= 1;
    }
    let (before, after) = excerpt.split_at(split_index);
    format!("{before}[CURSOR_POSITION]{after}")
}
1610
1611fn build_output_patch(
1612    cursor_path: &std::path::Path,
1613    cursor_excerpt: &str,
1614    editable_range: &std::ops::Range<usize>,
1615    model_output: &str,
1616) -> String {
1617    let old_text = &cursor_excerpt[editable_range.clone()];
1618
1619    let editable_start_row = cursor_excerpt[..editable_range.start]
1620        .chars()
1621        .filter(|&c| c == '\n')
1622        .count() as u32;
1623
1624    let diff_body = language::unified_diff_with_offsets(
1625        old_text,
1626        model_output,
1627        editable_start_row,
1628        editable_start_row,
1629    );
1630
1631    let mut patch = String::new();
1632    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1633    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1634    patch.push_str(&diff_body);
1635    patch
1636}
1637
1638pub(crate) fn get_column_indices(
1639    meta: &Option<SnowflakeResultSetMetaData>,
1640    names: &[&str],
1641) -> std::collections::HashMap<String, usize> {
1642    let mut indices = std::collections::HashMap::new();
1643    if let Some(meta) = meta {
1644        for (index, col) in meta.row_type.iter().enumerate() {
1645            for &name in names {
1646                if col.name.eq_ignore_ascii_case(name) {
1647                    indices.insert(name.to_string(), index);
1648                }
1649            }
1650        }
1651    }
1652    indices
1653}