pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::io::Read;
   9use std::sync::Arc;
  10use std::time::Duration;
  11
  12use zeta_prompt::ZetaPromptInput;
  13
  14use crate::example::Example;
  15use crate::progress::{InfoStyle, Progress, Step};
  16use edit_prediction::example_spec::{
  17    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
  18    TelemetrySource,
  19};
  20use std::fmt::Write as _;
  21
/// Value of the response `code` field that the response parsers in this file accept
/// as success; any other code that is present is reported as an error.
const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Value of the response `code` field that `run_sql_with_polling` treats as
/// "statement still executing" and keeps polling on.
const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` bound into the query issued by `fetch_captured_examples_after`.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` of request events, bound into the rejected and requested queries.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` of rejection events, joined against request events by `request_id`.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";

/// Sent as the `timeout` field of every statement request.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay before each status poll of a statement that is still running.
const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Upper bound on status polls before `run_sql_with_polling` gives up with an error.
const MAX_POLL_ATTEMPTS: usize = 120;
  31
  32/// Parse an input token of the form `captured-after:{timestamp}`.
  33pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  34    input.strip_prefix("captured-after:")
  35}
  36
  37/// Parse an input token of the form `rejected-after:{timestamp}`.
  38pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  39    input.strip_prefix("rejected-after:")
  40}
  41
  42/// Parse an input token of the form `requested-after:{timestamp}`.
  43pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  44    input.strip_prefix("requested-after:")
  45}
  46
  47pub async fn fetch_captured_examples_after(
  48    http_client: Arc<dyn HttpClient>,
  49    after_timestamps: &[String],
  50    max_rows_per_timestamp: usize,
  51    background_executor: BackgroundExecutor,
  52) -> Result<Vec<Example>> {
  53    if after_timestamps.is_empty() {
  54        return Ok(Vec::new());
  55    }
  56
  57    let progress = Progress::global();
  58
  59    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
  60        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
  61    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
  62        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
  63    )?;
  64    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
  65
  66    let mut all_examples = Vec::new();
  67
  68    for after_date in after_timestamps.iter() {
  69        let step_progress_name = format!(">{after_date}");
  70        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
  71        step_progress.set_substatus("querying");
  72
  73        let statement = indoc! {r#"
  74            SELECT
  75                event_properties:example AS example
  76            FROM events
  77            WHERE event_type = ?
  78                AND time > TRY_TO_TIMESTAMP_NTZ(?)
  79            ORDER BY time ASC
  80            LIMIT ?
  81        "#};
  82
  83        let request = json!({
  84            "statement": statement,
  85            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
  86            "database": "EVENTS",
  87            "schema": "PUBLIC",
  88            "warehouse": "DBT",
  89            "role": role,
  90            "bindings": {
  91                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
  92                "2": { "type": "TEXT", "value": after_date },
  93                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
  94            }
  95        });
  96
  97        let response = run_sql_with_polling(
  98            http_client.clone(),
  99            &base_url,
 100            &token,
 101            &request,
 102            &step_progress,
 103            background_executor.clone(),
 104        )
 105        .await?;
 106
 107        let total_rows = response
 108            .result_set_meta_data
 109            .as_ref()
 110            .and_then(|m| m.num_rows)
 111            .unwrap_or(response.data.len() as i64);
 112
 113        let num_partitions = response
 114            .result_set_meta_data
 115            .as_ref()
 116            .map(|m| m.partition_info.len())
 117            .unwrap_or(1)
 118            .max(1);
 119
 120        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 121        step_progress.set_substatus("parsing");
 122
 123        all_examples.extend(examples_from_response(&response)?);
 124
 125        if num_partitions > 1 {
 126            let statement_handle = response
 127                .statement_handle
 128                .as_ref()
 129                .context("response has multiple partitions but no statementHandle")?;
 130
 131            for partition in 1..num_partitions {
 132                step_progress.set_substatus(format!(
 133                    "fetching partition {}/{}",
 134                    partition + 1,
 135                    num_partitions
 136                ));
 137
 138                let partition_response = fetch_partition(
 139                    http_client.clone(),
 140                    &base_url,
 141                    &token,
 142                    statement_handle,
 143                    partition,
 144                )
 145                .await?;
 146
 147                all_examples.extend(examples_from_response(&partition_response)?);
 148            }
 149        }
 150
 151        step_progress.set_substatus("done");
 152    }
 153
 154    Ok(all_examples)
 155}
 156
/// JSON body returned by the Snowflake SQL API requests made in this file
/// (`run_sql` and `fetch_partition` both deserialize into this type).
///
/// Every field is `#[serde(default)]`, so a body that omits any of them still
/// deserializes.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON values addressed by column index.
    #[serde(default)]
    data: Vec<Vec<JsonValue>>,
    /// Column, row-count and partition metadata, when the response includes it.
    #[serde(default)]
    result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    code: Option<String>,
    /// Message included in the error raised for a non-success `code`.
    #[serde(default)]
    message: Option<String>,
    /// Handle used to poll the statement and to fetch additional partitions.
    #[serde(default)]
    statement_handle: Option<String>,
}
 171
/// The `resultSetMetaData` object of a statement response.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakeResultSetMetaData {
    /// Column descriptors, in column order; used to map column names to indices.
    #[serde(default, rename = "rowType")]
    row_type: Vec<SnowflakeColumnMeta>,
    /// Row count reported by the server, shown in progress output when present.
    #[serde(default)]
    num_rows: Option<i64>,
    /// One entry per result partition; only the number of entries is used.
    #[serde(default)]
    partition_info: Vec<SnowflakePartitionInfo>,
}
 182
/// Placeholder for one entry of `partitionInfo`.
///
/// No fields are read; entries are deserialized only so that they can be counted.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
 186
/// Descriptor of a single result column (an entry of `rowType`).
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name; defaults to an empty string when absent from the response.
    #[serde(default)]
    name: String,
}
 192
 193fn examples_from_response(
 194    response: &SnowflakeStatementResponse,
 195) -> Result<impl Iterator<Item = Example> + '_> {
 196    if let Some(code) = &response.code {
 197        if code != SNOWFLAKE_SUCCESS_CODE {
 198            anyhow::bail!(
 199                "snowflake sql api returned error code={code} message={}",
 200                response.message.as_deref().unwrap_or("<no message>")
 201            );
 202        }
 203    }
 204
 205    let example_index = response
 206        .result_set_meta_data
 207        .as_ref()
 208        .and_then(|m| {
 209            m.row_type.iter().enumerate().find_map(|(index, col)| {
 210                if col.name.eq_ignore_ascii_case("example") {
 211                    Some(index)
 212                } else {
 213                    None
 214                }
 215            })
 216        })
 217        .unwrap_or(0);
 218
 219    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
 220        let Some(example_value) = data_row.get(example_index) else {
 221            return None;
 222        };
 223        if example_value.is_null() {
 224            return None;
 225        }
 226
 227        let parse_result = match example_value {
 228            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
 229            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
 230        };
 231
 232        match parse_result {
 233            Ok(spec) => Some(Example {
 234                spec,
 235                prompt_inputs: None,
 236                prompt: None,
 237                predictions: Vec::new(),
 238                score: Vec::new(),
 239                qa: Vec::new(),
 240                state: None,
 241            }),
 242            Err(error) => {
 243                let raw_json = serde_json::to_string_pretty(example_value)
 244                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
 245                log::error!(
 246                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
 247                );
 248                None
 249            }
 250        }
 251    });
 252
 253    Ok(iter)
 254}
 255
 256async fn run_sql_with_polling(
 257    http_client: Arc<dyn HttpClient>,
 258    base_url: &str,
 259    token: &str,
 260    request: &serde_json::Value,
 261    step_progress: &crate::progress::StepProgress,
 262    background_executor: BackgroundExecutor,
 263) -> Result<SnowflakeStatementResponse> {
 264    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 265
 266    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 267        let statement_handle = response
 268            .statement_handle
 269            .as_ref()
 270            .context("async query response missing statementHandle")?
 271            .clone();
 272
 273        for attempt in 1..=MAX_POLL_ATTEMPTS {
 274            step_progress.set_substatus(format!("polling ({attempt})"));
 275
 276            background_executor.timer(POLL_INTERVAL).await;
 277
 278            response =
 279                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 280
 281            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 282                break;
 283            }
 284        }
 285
 286        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 287            anyhow::bail!(
 288                "query still running after {} poll attempts ({} seconds)",
 289                MAX_POLL_ATTEMPTS,
 290                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 291            );
 292        }
 293    }
 294
 295    Ok(response)
 296}
 297
 298async fn fetch_partition(
 299    http_client: Arc<dyn HttpClient>,
 300    base_url: &str,
 301    token: &str,
 302    statement_handle: &str,
 303    partition: usize,
 304) -> Result<SnowflakeStatementResponse> {
 305    let url = format!(
 306        "{}/api/v2/statements/{}?partition={}",
 307        base_url.trim_end_matches('/'),
 308        statement_handle,
 309        partition
 310    );
 311
 312    let http_request = Request::builder()
 313        .method(Method::GET)
 314        .uri(url.as_str())
 315        .header("Authorization", format!("Bearer {token}"))
 316        .header(
 317            "X-Snowflake-Authorization-Token-Type",
 318            "PROGRAMMATIC_ACCESS_TOKEN",
 319        )
 320        .header("Accept", "application/json")
 321        .header("Accept-Encoding", "gzip")
 322        .header("User-Agent", "edit_prediction_cli")
 323        .body(AsyncBody::empty())?;
 324
 325    let response = http_client
 326        .send(http_request)
 327        .await
 328        .context("failed to send partition request to Snowflake SQL API")?;
 329
 330    let status = response.status();
 331    let content_encoding = response
 332        .headers()
 333        .get("content-encoding")
 334        .and_then(|v| v.to_str().ok())
 335        .map(|s| s.to_lowercase());
 336
 337    let body_bytes = {
 338        use futures::AsyncReadExt as _;
 339
 340        let mut body = response.into_body();
 341        let mut bytes = Vec::new();
 342        body.read_to_end(&mut bytes)
 343            .await
 344            .context("failed to read Snowflake SQL API partition response body")?;
 345        bytes
 346    };
 347
 348    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 349        let mut decoder = GzDecoder::new(&body_bytes[..]);
 350        let mut decompressed = Vec::new();
 351        decoder
 352            .read_to_end(&mut decompressed)
 353            .context("failed to decompress gzip response")?;
 354        decompressed
 355    } else {
 356        body_bytes
 357    };
 358
 359    if !status.is_success() && status.as_u16() != 202 {
 360        let body_text = String::from_utf8_lossy(&body_bytes);
 361        anyhow::bail!(
 362            "snowflake sql api partition request http {}: {}",
 363            status.as_u16(),
 364            body_text
 365        );
 366    }
 367
 368    if body_bytes.is_empty() {
 369        anyhow::bail!(
 370            "snowflake sql api partition {} returned empty response body (http {})",
 371            partition,
 372            status.as_u16()
 373        );
 374    }
 375
 376    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 377        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 378        format!(
 379            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 380            partition,
 381            status.as_u16(),
 382            body_preview
 383        )
 384    })
 385}
 386
 387async fn run_sql(
 388    http_client: Arc<dyn HttpClient>,
 389    base_url: &str,
 390    token: &str,
 391    request: &serde_json::Value,
 392) -> Result<SnowflakeStatementResponse> {
 393    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 394
 395    let request_body =
 396        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 397
 398    let http_request = Request::builder()
 399        .method(Method::POST)
 400        .uri(url.as_str())
 401        .header("Authorization", format!("Bearer {token}"))
 402        .header(
 403            "X-Snowflake-Authorization-Token-Type",
 404            "PROGRAMMATIC_ACCESS_TOKEN",
 405        )
 406        .header("Content-Type", "application/json")
 407        .header("Accept", "application/json")
 408        .header("User-Agent", "edit_prediction_cli")
 409        .body(AsyncBody::from(request_body.clone()))?;
 410
 411    let response = http_client
 412        .send(http_request)
 413        .await
 414        .context("failed to send request to Snowflake SQL API")?;
 415
 416    let status = response.status();
 417    let body_bytes = {
 418        use futures::AsyncReadExt as _;
 419
 420        let mut body = response.into_body();
 421        let mut bytes = Vec::new();
 422        body.read_to_end(&mut bytes)
 423            .await
 424            .context("failed to read Snowflake SQL API response body")?;
 425        bytes
 426    };
 427
 428    if !status.is_success() && status.as_u16() != 202 {
 429        let body_text = String::from_utf8_lossy(&body_bytes);
 430        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 431    }
 432
 433    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 434        .context("failed to parse Snowflake SQL API response JSON")
 435}
 436
 437pub async fn fetch_rejected_examples_after(
 438    http_client: Arc<dyn HttpClient>,
 439    after_timestamps: &[String],
 440    max_rows_per_timestamp: usize,
 441    background_executor: BackgroundExecutor,
 442) -> Result<Vec<Example>> {
 443    if after_timestamps.is_empty() {
 444        return Ok(Vec::new());
 445    }
 446
 447    let progress = Progress::global();
 448
 449    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 450        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 451    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 452        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 453    )?;
 454    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 455
 456    let mut all_examples = Vec::new();
 457
 458    for after_date in after_timestamps.iter() {
 459        let step_progress_name = format!("rejected>{after_date}");
 460        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 461        step_progress.set_substatus("querying");
 462
 463        // Join rejected events with their corresponding request events to get the full context.
 464        // We filter for V3 sampling data which contains the structured input we need.
 465        // We also filter for predictions that were actually shown to the user (was_shown = true)
 466        // to focus on explicit user rejections rather than implicit cancellations.
 467        let statement = indoc! {r#"
 468            SELECT
 469                req.event_properties:request_id::string AS request_id,
 470                req.device_id::string AS device_id,
 471                req.time::string AS time,
 472                req.event_properties:input AS input,
 473                req.event_properties:prompt::string AS prompt,
 474                req.event_properties:output::string AS output,
 475                rej.event_properties:was_shown::boolean AS was_shown,
 476                rej.event_properties:reason::string AS reason
 477            FROM events req
 478            INNER JOIN events rej
 479                ON req.event_properties:request_id = rej.event_properties:request_id
 480            WHERE req.event_type = ?
 481                AND rej.event_type = ?
 482                AND req.event_properties:version = 'V3'
 483                AND rej.event_properties:was_shown = true
 484                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 485            ORDER BY req.time ASC
 486            LIMIT ?
 487        "#};
 488
 489        let request = json!({
 490            "statement": statement,
 491            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 492            "database": "EVENTS",
 493            "schema": "PUBLIC",
 494            "warehouse": "DBT",
 495            "role": role,
 496            "bindings": {
 497                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 498                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
 499                "3": { "type": "TEXT", "value": after_date },
 500                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 501            }
 502        });
 503
 504        let response = run_sql_with_polling(
 505            http_client.clone(),
 506            &base_url,
 507            &token,
 508            &request,
 509            &step_progress,
 510            background_executor.clone(),
 511        )
 512        .await?;
 513
 514        let total_rows = response
 515            .result_set_meta_data
 516            .as_ref()
 517            .and_then(|m| m.num_rows)
 518            .unwrap_or(response.data.len() as i64);
 519
 520        let num_partitions = response
 521            .result_set_meta_data
 522            .as_ref()
 523            .map(|m| m.partition_info.len())
 524            .unwrap_or(1)
 525            .max(1);
 526
 527        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 528        step_progress.set_substatus("parsing");
 529
 530        all_examples.extend(rejected_examples_from_response(&response)?);
 531
 532        if num_partitions > 1 {
 533            let statement_handle = response
 534                .statement_handle
 535                .as_ref()
 536                .context("response has multiple partitions but no statementHandle")?;
 537
 538            for partition in 1..num_partitions {
 539                step_progress.set_substatus(format!(
 540                    "fetching partition {}/{}",
 541                    partition + 1,
 542                    num_partitions
 543                ));
 544
 545                let partition_response = fetch_partition(
 546                    http_client.clone(),
 547                    &base_url,
 548                    &token,
 549                    statement_handle,
 550                    partition,
 551                )
 552                .await?;
 553
 554                all_examples.extend(rejected_examples_from_response(&partition_response)?);
 555            }
 556        }
 557
 558        step_progress.set_substatus("done");
 559    }
 560
 561    Ok(all_examples)
 562}
 563
 564pub async fn fetch_requested_examples_after(
 565    http_client: Arc<dyn HttpClient>,
 566    after_timestamps: &[String],
 567    max_rows_per_timestamp: usize,
 568    background_executor: BackgroundExecutor,
 569) -> Result<Vec<Example>> {
 570    if after_timestamps.is_empty() {
 571        return Ok(Vec::new());
 572    }
 573
 574    let progress = Progress::global();
 575
 576    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 577        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 578    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 579        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 580    )?;
 581    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 582
 583    let mut all_examples = Vec::new();
 584
 585    for after_date in after_timestamps.iter() {
 586        let step_progress_name = format!("requested>{after_date}");
 587        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 588        step_progress.set_substatus("querying");
 589
 590        let statement = indoc! {r#"
 591            SELECT
 592                req.event_properties:request_id::string AS request_id,
 593                req.device_id::string AS device_id,
 594                req.time::string AS time,
 595                req.event_properties:input AS input
 596            FROM events req
 597            WHERE req.event_type = ?
 598                AND req.event_properties:version = 'V3'
 599                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 600            ORDER BY req.time ASC
 601            LIMIT ?
 602        "#};
 603
 604        let request = json!({
 605            "statement": statement,
 606            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 607            "database": "EVENTS",
 608            "schema": "PUBLIC",
 609            "warehouse": "DBT",
 610            "role": role,
 611            "bindings": {
 612                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 613                "2": { "type": "TEXT", "value": after_date },
 614                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 615            }
 616        });
 617
 618        let response = run_sql_with_polling(
 619            http_client.clone(),
 620            &base_url,
 621            &token,
 622            &request,
 623            &step_progress,
 624            background_executor.clone(),
 625        )
 626        .await?;
 627
 628        let total_rows = response
 629            .result_set_meta_data
 630            .as_ref()
 631            .and_then(|m| m.num_rows)
 632            .unwrap_or(response.data.len() as i64);
 633
 634        let num_partitions = response
 635            .result_set_meta_data
 636            .as_ref()
 637            .map(|m| m.partition_info.len())
 638            .unwrap_or(1)
 639            .max(1);
 640
 641        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 642        step_progress.set_substatus("parsing");
 643
 644        let column_indices = get_column_indices(
 645            &response.result_set_meta_data,
 646            &["request_id", "device_id", "time", "input"],
 647        );
 648
 649        all_examples.extend(requested_examples_from_response(
 650            &response,
 651            &column_indices,
 652        )?);
 653
 654        if num_partitions > 1 {
 655            let statement_handle = response
 656                .statement_handle
 657                .as_ref()
 658                .context("response has multiple partitions but no statementHandle")?;
 659
 660            for partition in 1..num_partitions {
 661                step_progress.set_substatus(format!(
 662                    "fetching partition {}/{}",
 663                    partition + 1,
 664                    num_partitions
 665                ));
 666
 667                let partition_response = fetch_partition(
 668                    http_client.clone(),
 669                    &base_url,
 670                    &token,
 671                    statement_handle,
 672                    partition,
 673                )
 674                .await?;
 675
 676                all_examples.extend(requested_examples_from_response(
 677                    &partition_response,
 678                    &column_indices,
 679                )?);
 680            }
 681        }
 682
 683        step_progress.set_substatus("done");
 684    }
 685
 686    Ok(all_examples)
 687}
 688
 689fn requested_examples_from_response<'a>(
 690    response: &'a SnowflakeStatementResponse,
 691    column_indices: &'a std::collections::HashMap<String, usize>,
 692) -> Result<impl Iterator<Item = Example> + 'a> {
 693    if let Some(code) = &response.code {
 694        if code != SNOWFLAKE_SUCCESS_CODE {
 695            anyhow::bail!(
 696                "snowflake sql api returned error code={code} message={}",
 697                response.message.as_deref().unwrap_or("<no message>")
 698            );
 699        }
 700    }
 701
 702    let iter = response
 703        .data
 704        .iter()
 705        .enumerate()
 706        .filter_map(move |(row_index, data_row)| {
 707            let get_string = |name: &str| -> Option<String> {
 708                let index = column_indices.get(name).copied()?;
 709                match data_row.get(index)? {
 710                    JsonValue::String(s) => Some(s.clone()),
 711                    JsonValue::Null => None,
 712                    other => Some(other.to_string()),
 713                }
 714            };
 715
 716            let get_json = |name: &str| -> Option<JsonValue> {
 717                let index = column_indices.get(name).copied()?;
 718                let value = data_row.get(index)?;
 719                if value.is_null() {
 720                    return None;
 721                }
 722                match value {
 723                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 724                    other => Some(other.clone()),
 725                }
 726            };
 727
 728            let request_id_str = get_string("request_id");
 729            let device_id = get_string("device_id");
 730            let time = get_string("time");
 731            let input_json = get_json("input");
 732            let input: Option<ZetaPromptInput> =
 733                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
 734
 735            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
 736                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
 737                    Some(build_example_from_snowflake(
 738                        request_id,
 739                        device_id,
 740                        time,
 741                        input,
 742                        vec!["requested".to_string()],
 743                        None,
 744                    ))
 745                }
 746                _ => {
 747                    log::warn!(
 748                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
 749                        request_id_str.is_some(),
 750                        device_id.is_some(),
 751                        time.is_some(),
 752                        input_json.is_some(),
 753                    );
 754                    None
 755                }
 756            }
 757        });
 758
 759    Ok(iter)
 760}
 761
 762fn rejected_examples_from_response(
 763    response: &SnowflakeStatementResponse,
 764) -> Result<impl Iterator<Item = Example> + '_> {
 765    if let Some(code) = &response.code {
 766        if code != SNOWFLAKE_SUCCESS_CODE {
 767            anyhow::bail!(
 768                "snowflake sql api returned error code={code} message={}",
 769                response.message.as_deref().unwrap_or("<no message>")
 770            );
 771        }
 772    }
 773
 774    let column_indices = get_column_indices(
 775        &response.result_set_meta_data,
 776        &[
 777            "request_id",
 778            "device_id",
 779            "time",
 780            "input",
 781            "prompt",
 782            "output",
 783            "was_shown",
 784            "reason",
 785        ],
 786    );
 787
 788    let iter = response
 789        .data
 790        .iter()
 791        .enumerate()
 792        .filter_map(move |(row_index, data_row)| {
 793            let get_string = |name: &str| -> Option<String> {
 794                let index = column_indices.get(name).copied()?;
 795                match data_row.get(index)? {
 796                    JsonValue::String(s) => Some(s.clone()),
 797                    JsonValue::Null => None,
 798                    other => Some(other.to_string()),
 799                }
 800            };
 801
 802            let get_json = |name: &str| -> Option<JsonValue> {
 803                let index = column_indices.get(name).copied()?;
 804                let value = data_row.get(index)?;
 805                if value.is_null() {
 806                    return None;
 807                }
 808                match value {
 809                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 810                    other => Some(other.clone()),
 811                }
 812            };
 813
 814            let get_bool = |name: &str| -> Option<bool> {
 815                let index = column_indices.get(name).copied()?;
 816                match data_row.get(index)? {
 817                    JsonValue::Bool(b) => Some(*b),
 818                    JsonValue::String(s) => s.parse().ok(),
 819                    _ => None,
 820                }
 821            };
 822
 823            let request_id_str = get_string("request_id");
 824            let device_id = get_string("device_id");
 825            let time = get_string("time");
 826            let input_json = get_json("input");
 827            let input: Option<ZetaPromptInput> =
 828                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
 829            let output = get_string("output");
 830            let was_shown = get_bool("was_shown");
 831            let reason = get_string("reason");
 832
 833            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
 834                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
 835                    Some(build_rejected_example(
 836                        request_id,
 837                        device_id,
 838                        time,
 839                        input,
 840                        output,
 841                        was_shown,
 842                        reason,
 843                    ))
 844                }
 845                _ => {
 846                    log::warn!(
 847                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
 848                        request_id_str.is_some(),
 849                        device_id.is_some(),
 850                        time.is_some(),
 851                        input_json.is_some(),
 852                        output.is_some(),
 853                        was_shown.is_some(),
 854                        reason.is_some()
 855                    );
 856                    None
 857                }
 858            }
 859        });
 860
 861    Ok(iter)
 862}
 863
 864fn build_rejected_example(
 865    request_id: String,
 866    device_id: String,
 867    time: String,
 868    input: ZetaPromptInput,
 869    output: String,
 870    was_shown: bool,
 871    reason: String,
 872) -> Example {
 873    let rejected_patch = build_output_patch(
 874        &input.cursor_path,
 875        input.cursor_excerpt.as_ref(),
 876        &input.editable_range_in_excerpt,
 877        &output,
 878    );
 879    let mut example = build_example_from_snowflake(
 880        request_id,
 881        device_id,
 882        time,
 883        input,
 884        vec![format!("rejection:{}", reason.to_lowercase())],
 885        Some(RejectionInfo { reason, was_shown }),
 886    );
 887    example.spec.rejected_patch = Some(rejected_patch);
 888    example
 889}
 890
/// Rejection details for a prediction, passed to `build_example_from_snowflake`
/// and copied into the example's telemetry. When no `RejectionInfo` is supplied,
/// the telemetry records an empty reason and `was_shown = false`.
struct RejectionInfo {
    /// Rejection reason as read from the row's `reason` column.
    reason: String,
    /// Value of the row's `was_shown` column.
    was_shown: bool,
}
 895
 896fn build_example_from_snowflake(
 897    request_id: String,
 898    device_id: String,
 899    time: String,
 900    input: ZetaPromptInput,
 901    tags: Vec<String>,
 902    rejection: Option<RejectionInfo>,
 903) -> Example {
 904    let events: Vec<CapturedEvent> = input
 905        .events
 906        .iter()
 907        .map(|event| match event.as_ref() {
 908            zeta_prompt::Event::BufferChange {
 909                path,
 910                old_path,
 911                diff,
 912                predicted,
 913                in_open_source_repo,
 914            } => CapturedEvent {
 915                path: path.clone(),
 916                old_path: old_path.clone(),
 917                diff: diff.clone(),
 918                predicted: *predicted,
 919                in_open_source_repo: *in_open_source_repo,
 920            },
 921        })
 922        .collect();
 923
 924    let related_files: Vec<CapturedRelatedFile> = input
 925        .related_files
 926        .iter()
 927        .map(|rf| CapturedRelatedFile {
 928            path: rf.path.clone(),
 929            max_row: rf.max_row,
 930            excerpts: rf
 931                .excerpts
 932                .iter()
 933                .map(|e| CapturedRelatedExcerpt {
 934                    row_range: e.row_range.clone(),
 935                    text: e.text.to_string(),
 936                })
 937                .collect(),
 938        })
 939        .collect();
 940
 941    let cursor_excerpt = input.cursor_excerpt.as_ref();
 942    let cursor_offset = input.cursor_offset_in_excerpt;
 943
 944    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
 945
 946    let mut edit_history = String::new();
 947    for event in &input.events {
 948        zeta_prompt::write_event(&mut edit_history, event);
 949        edit_history.push('\n');
 950    }
 951
 952    let (rejection_reason, was_shown) = match &rejection {
 953        Some(r) => (r.reason.clone(), r.was_shown),
 954        None => (String::new(), false),
 955    };
 956
 957    let spec = ExampleSpec {
 958        name: request_id.clone(),
 959        repository_url: String::new(),
 960        revision: String::new(),
 961        tags,
 962        reasoning: None,
 963        uncommitted_diff: String::new(),
 964        cursor_path: input.cursor_path.clone(),
 965        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
 966        edit_history,
 967        expected_patches: Vec::new(),
 968        rejected_patch: None,
 969        captured_prompt_input: Some(CapturedPromptInput {
 970            cursor_file_content: cursor_excerpt.to_string(),
 971            cursor_offset,
 972            cursor_row,
 973            cursor_column,
 974            events,
 975            related_files,
 976        }),
 977        telemetry: Some(TelemetrySource {
 978            request_id,
 979            device_id,
 980            time,
 981            rejection_reason,
 982            was_shown,
 983        }),
 984    };
 985
 986    Example {
 987        spec,
 988        prompt_inputs: None,
 989        prompt: None,
 990        predictions: Vec::new(),
 991        score: Vec::new(),
 992        qa: Vec::new(),
 993        state: None,
 994    }
 995}
 996
 997fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
 998    let mut row = 0u32;
 999    let mut last_newline_offset = 0;
1000    for (i, c) in text.char_indices() {
1001        if i >= offset {
1002            break;
1003        }
1004        if c == '\n' {
1005            row += 1;
1006            last_newline_offset = i + 1;
1007        }
1008    }
1009    let column = (offset - last_newline_offset) as u32;
1010    (row, column)
1011}
1012
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset into `excerpt`. Offsets past the end are
/// clamped to `excerpt.len()`, and an offset that lands inside a multi-byte
/// character is moved back to the start of that character, so a malformed
/// offset never causes a slicing panic.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_index = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_index) {
        split_index -= 1;
    }
    let (before, after) = excerpt.split_at(split_index);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
1018
1019fn build_output_patch(
1020    cursor_path: &std::path::Path,
1021    cursor_excerpt: &str,
1022    editable_range: &std::ops::Range<usize>,
1023    model_output: &str,
1024) -> String {
1025    let old_text = &cursor_excerpt[editable_range.clone()];
1026
1027    let editable_start_row = cursor_excerpt[..editable_range.start]
1028        .chars()
1029        .filter(|&c| c == '\n')
1030        .count() as u32;
1031
1032    let diff_body = language::unified_diff_with_offsets(
1033        old_text,
1034        model_output,
1035        editable_start_row,
1036        editable_start_row,
1037    );
1038
1039    let mut patch = String::new();
1040    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1041    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1042    patch.push_str(&diff_body);
1043    patch
1044}
1045
1046fn get_column_indices(
1047    meta: &Option<SnowflakeResultSetMetaData>,
1048    names: &[&str],
1049) -> std::collections::HashMap<String, usize> {
1050    let mut indices = std::collections::HashMap::new();
1051    if let Some(meta) = meta {
1052        for (index, col) in meta.row_type.iter().enumerate() {
1053            for &name in names {
1054                if col.name.eq_ignore_ascii_case(name) {
1055                    indices.insert(name.to_string(), index);
1056                }
1057            }
1058        }
1059    }
1060    indices
1061}