pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::io::Read;
   9use std::sync::Arc;
  10use std::time::Duration;
  11use telemetry_events::EditPredictionRating;
  12
  13use zeta_prompt::ZetaPromptInput;
  14
  15use crate::example::Example;
  16use crate::progress::{InfoStyle, Progress, Step};
  17use crate::sync_deployments::EDIT_PREDICTION_DEPLOYMENT_EVENT;
  18use edit_prediction::example_spec::{
  19    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
  20    TelemetrySource,
  21};
  22use std::fmt::Write as _;
  23
  24pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
  25pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
  26const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
  27const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
  28const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
  29const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
  30
  31/// Minimum Zed version for filtering captured examples.
  32/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
  33/// where `zed_version >= 0.224.1`.
  34#[derive(Clone, Copy, Debug)]
  35pub struct MinCaptureVersion {
  36    pub minor: u32,
  37    pub patch: u32,
  38}
  39
  40const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
  41pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
  42pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
  43
  44/// Parse an input token of the form `captured-after:{timestamp}`.
  45pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  46    input.strip_prefix("captured-after:")
  47}
  48
  49/// Parse an input token of the form `rejected-after:{timestamp}`.
  50pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  51    input.strip_prefix("rejected-after:")
  52}
  53
  54/// Parse an input token of the form `requested-after:{timestamp}`.
  55pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  56    input.strip_prefix("requested-after:")
  57}
  58
  59/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  60/// or `rated-negative-after:{timestamp}`.
  61/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  62pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  63    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  64        Some((timestamp, Some(EditPredictionRating::Positive)))
  65    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  66        Some((timestamp, Some(EditPredictionRating::Negative)))
  67    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  68        Some((timestamp, None))
  69    } else {
  70        None
  71    }
  72}
  73
  74pub async fn fetch_captured_examples_after(
  75    http_client: Arc<dyn HttpClient>,
  76    after_timestamps: &[String],
  77    max_rows_per_timestamp: usize,
  78    offset: usize,
  79    background_executor: BackgroundExecutor,
  80    _min_capture_version: Option<MinCaptureVersion>,
  81) -> Result<Vec<Example>> {
  82    if after_timestamps.is_empty() {
  83        return Ok(Vec::new());
  84    }
  85
  86    let progress = Progress::global();
  87
  88    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
  89        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
  90    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
  91        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
  92    )?;
  93    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
  94
  95    let mut all_examples = Vec::new();
  96
  97    for after_date in after_timestamps.iter() {
  98        let step_progress_name = format!(">{after_date}");
  99        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 100        step_progress.set_substatus("querying");
 101
 102        let statement = indoc! {r#"
 103            SELECT
 104                event_properties:example AS example
 105            FROM events
 106            WHERE event_type = ?
 107                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 108            ORDER BY time ASC
 109            LIMIT ?
 110            OFFSET ?
 111        "#};
 112
 113        let request = json!({
 114            "statement": statement,
 115            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 116            "database": "EVENTS",
 117            "schema": "PUBLIC",
 118            "warehouse": "DBT",
 119            "role": role,
 120            "bindings": {
 121                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 122                "2": { "type": "TEXT", "value": after_date },
 123                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 124                "4": { "type": "FIXED", "value": offset.to_string() }
 125            }
 126        });
 127
 128        let response = run_sql_with_polling(
 129            http_client.clone(),
 130            &base_url,
 131            &token,
 132            &request,
 133            &step_progress,
 134            background_executor.clone(),
 135        )
 136        .await?;
 137
 138        let total_rows = response
 139            .result_set_meta_data
 140            .as_ref()
 141            .and_then(|m| m.num_rows)
 142            .unwrap_or(response.data.len() as i64);
 143
 144        let num_partitions = response
 145            .result_set_meta_data
 146            .as_ref()
 147            .map(|m| m.partition_info.len())
 148            .unwrap_or(1)
 149            .max(1);
 150
 151        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 152        step_progress.set_substatus("parsing");
 153
 154        let example_index = response
 155            .result_set_meta_data
 156            .as_ref()
 157            .and_then(|m| {
 158                m.row_type.iter().enumerate().find_map(|(index, col)| {
 159                    if col.name.eq_ignore_ascii_case("example") {
 160                        Some(index)
 161                    } else {
 162                        None
 163                    }
 164                })
 165            })
 166            .unwrap_or(0);
 167
 168        all_examples.extend(examples_from_response(&response, example_index)?);
 169
 170        if num_partitions > 1 {
 171            let statement_handle = response
 172                .statement_handle
 173                .as_ref()
 174                .context("response has multiple partitions but no statementHandle")?;
 175
 176            for partition in 1..num_partitions {
 177                step_progress.set_substatus(format!(
 178                    "fetching partition {}/{}",
 179                    partition + 1,
 180                    num_partitions
 181                ));
 182
 183                let partition_response = fetch_partition(
 184                    http_client.clone(),
 185                    &base_url,
 186                    &token,
 187                    statement_handle,
 188                    partition,
 189                )
 190                .await?;
 191
 192                all_examples.extend(examples_from_response(&partition_response, example_index)?);
 193            }
 194        }
 195
 196        step_progress.set_substatus("done");
 197    }
 198
 199    Ok(all_examples)
 200}
 201
 202#[derive(Debug, Clone, Deserialize)]
 203#[serde(rename_all = "camelCase")]
 204pub(crate) struct SnowflakeStatementResponse {
 205    #[serde(default)]
 206    pub(crate) data: Vec<Vec<JsonValue>>,
 207    #[serde(default)]
 208    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
 209    #[serde(default)]
 210    pub(crate) code: Option<String>,
 211    #[serde(default)]
 212    pub(crate) message: Option<String>,
 213    #[serde(default)]
 214    pub(crate) statement_handle: Option<String>,
 215}
 216
 217#[derive(Debug, Clone, Deserialize)]
 218#[serde(rename_all = "camelCase")]
 219pub(crate) struct SnowflakeResultSetMetaData {
 220    #[serde(default, rename = "rowType")]
 221    row_type: Vec<SnowflakeColumnMeta>,
 222    #[serde(default)]
 223    num_rows: Option<i64>,
 224    #[serde(default)]
 225    partition_info: Vec<SnowflakePartitionInfo>,
 226}
 227
 228#[derive(Debug, Clone, Deserialize)]
 229#[serde(rename_all = "camelCase")]
 230struct SnowflakePartitionInfo {}
 231
 232#[derive(Debug, Clone, Deserialize)]
 233struct SnowflakeColumnMeta {
 234    #[serde(default)]
 235    name: String,
 236}
 237
 238fn examples_from_response(
 239    response: &SnowflakeStatementResponse,
 240    example_index: usize,
 241) -> Result<impl Iterator<Item = Example> + '_> {
 242    if let Some(code) = &response.code {
 243        if code != SNOWFLAKE_SUCCESS_CODE {
 244            anyhow::bail!(
 245                "snowflake sql api returned error code={code} message={}",
 246                response.message.as_deref().unwrap_or("<no message>")
 247            );
 248        }
 249    }
 250
 251    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
 252        let Some(example_value) = data_row.get(example_index) else {
 253            return None;
 254        };
 255        if example_value.is_null() {
 256            return None;
 257        }
 258
 259        let parse_result = match example_value {
 260            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
 261            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
 262        };
 263
 264        match parse_result {
 265            Ok(spec) => Some(Example {
 266                spec,
 267                prompt_inputs: None,
 268                prompt: None,
 269                predictions: Vec::new(),
 270                score: Vec::new(),
 271                qa: Vec::new(),
 272                state: None,
 273            }),
 274            Err(error) => {
 275                let raw_json = serde_json::to_string_pretty(example_value)
 276                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
 277                log::error!(
 278                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
 279                );
 280                None
 281            }
 282        }
 283    });
 284
 285    Ok(iter)
 286}
 287
 288async fn run_sql_with_polling(
 289    http_client: Arc<dyn HttpClient>,
 290    base_url: &str,
 291    token: &str,
 292    request: &serde_json::Value,
 293    step_progress: &crate::progress::StepProgress,
 294    background_executor: BackgroundExecutor,
 295) -> Result<SnowflakeStatementResponse> {
 296    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 297
 298    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 299        let statement_handle = response
 300            .statement_handle
 301            .as_ref()
 302            .context("async query response missing statementHandle")?
 303            .clone();
 304
 305        for attempt in 1..=MAX_POLL_ATTEMPTS {
 306            step_progress.set_substatus(format!("polling ({attempt})"));
 307
 308            background_executor.timer(POLL_INTERVAL).await;
 309
 310            response =
 311                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 312
 313            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 314                break;
 315            }
 316        }
 317
 318        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 319            anyhow::bail!(
 320                "query still running after {} poll attempts ({} seconds)",
 321                MAX_POLL_ATTEMPTS,
 322                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 323            );
 324        }
 325    }
 326
 327    Ok(response)
 328}
 329
 330pub(crate) async fn fetch_partition(
 331    http_client: Arc<dyn HttpClient>,
 332    base_url: &str,
 333    token: &str,
 334    statement_handle: &str,
 335    partition: usize,
 336) -> Result<SnowflakeStatementResponse> {
 337    let url = format!(
 338        "{}/api/v2/statements/{}?partition={}",
 339        base_url.trim_end_matches('/'),
 340        statement_handle,
 341        partition
 342    );
 343
 344    let http_request = Request::builder()
 345        .method(Method::GET)
 346        .uri(url.as_str())
 347        .header("Authorization", format!("Bearer {token}"))
 348        .header(
 349            "X-Snowflake-Authorization-Token-Type",
 350            "PROGRAMMATIC_ACCESS_TOKEN",
 351        )
 352        .header("Accept", "application/json")
 353        .header("Accept-Encoding", "gzip")
 354        .header("User-Agent", "edit_prediction_cli")
 355        .body(AsyncBody::empty())?;
 356
 357    let response = http_client
 358        .send(http_request)
 359        .await
 360        .context("failed to send partition request to Snowflake SQL API")?;
 361
 362    let status = response.status();
 363    let content_encoding = response
 364        .headers()
 365        .get("content-encoding")
 366        .and_then(|v| v.to_str().ok())
 367        .map(|s| s.to_lowercase());
 368
 369    let body_bytes = {
 370        use futures::AsyncReadExt as _;
 371
 372        let mut body = response.into_body();
 373        let mut bytes = Vec::new();
 374        body.read_to_end(&mut bytes)
 375            .await
 376            .context("failed to read Snowflake SQL API partition response body")?;
 377        bytes
 378    };
 379
 380    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 381        let mut decoder = GzDecoder::new(&body_bytes[..]);
 382        let mut decompressed = Vec::new();
 383        decoder
 384            .read_to_end(&mut decompressed)
 385            .context("failed to decompress gzip response")?;
 386        decompressed
 387    } else {
 388        body_bytes
 389    };
 390
 391    if !status.is_success() && status.as_u16() != 202 {
 392        let body_text = String::from_utf8_lossy(&body_bytes);
 393        anyhow::bail!(
 394            "snowflake sql api partition request http {}: {}",
 395            status.as_u16(),
 396            body_text
 397        );
 398    }
 399
 400    if body_bytes.is_empty() {
 401        anyhow::bail!(
 402            "snowflake sql api partition {} returned empty response body (http {})",
 403            partition,
 404            status.as_u16()
 405        );
 406    }
 407
 408    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 409        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 410        format!(
 411            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 412            partition,
 413            status.as_u16(),
 414            body_preview
 415        )
 416    })
 417}
 418
 419pub(crate) async fn run_sql(
 420    http_client: Arc<dyn HttpClient>,
 421    base_url: &str,
 422    token: &str,
 423    request: &serde_json::Value,
 424) -> Result<SnowflakeStatementResponse> {
 425    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 426
 427    let request_body =
 428        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 429
 430    let http_request = Request::builder()
 431        .method(Method::POST)
 432        .uri(url.as_str())
 433        .header("Authorization", format!("Bearer {token}"))
 434        .header(
 435            "X-Snowflake-Authorization-Token-Type",
 436            "PROGRAMMATIC_ACCESS_TOKEN",
 437        )
 438        .header("Content-Type", "application/json")
 439        .header("Accept", "application/json")
 440        .header("User-Agent", "edit_prediction_cli")
 441        .body(AsyncBody::from(request_body.clone()))?;
 442
 443    let response = http_client
 444        .send(http_request)
 445        .await
 446        .context("failed to send request to Snowflake SQL API")?;
 447
 448    let status = response.status();
 449    let body_bytes = {
 450        use futures::AsyncReadExt as _;
 451
 452        let mut body = response.into_body();
 453        let mut bytes = Vec::new();
 454        body.read_to_end(&mut bytes)
 455            .await
 456            .context("failed to read Snowflake SQL API response body")?;
 457        bytes
 458    };
 459
 460    if !status.is_success() && status.as_u16() != 202 {
 461        let body_text = String::from_utf8_lossy(&body_bytes);
 462        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 463    }
 464
 465    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 466        .context("failed to parse Snowflake SQL API response JSON")
 467}
 468
 469pub async fn fetch_rejected_examples_after(
 470    http_client: Arc<dyn HttpClient>,
 471    after_timestamps: &[String],
 472    max_rows_per_timestamp: usize,
 473    offset: usize,
 474    background_executor: BackgroundExecutor,
 475    min_capture_version: Option<MinCaptureVersion>,
 476) -> Result<Vec<Example>> {
 477    if after_timestamps.is_empty() {
 478        return Ok(Vec::new());
 479    }
 480
 481    let progress = Progress::global();
 482
 483    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 484        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 485    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 486        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 487    )?;
 488    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 489
 490    let mut all_examples = Vec::new();
 491
 492    for after_date in after_timestamps.iter() {
 493        let step_progress_name = format!("rejected>{after_date}");
 494        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 495        step_progress.set_substatus("querying");
 496
 497        // Join rejected events with their corresponding request events to get the full context.
 498        // We filter for V3 sampling data which contains the structured input we need.
 499        // We also filter for predictions that were actually shown to the user (was_shown = true)
 500        // to focus on explicit user rejections rather than implicit cancellations.
 501        let statement = indoc! {r#"
 502            SELECT
 503                req.event_properties:request_id::string AS request_id,
 504                req.device_id::string AS device_id,
 505                req.time::string AS time,
 506                req.event_properties:input AS input,
 507                req.event_properties:prompt::string AS prompt,
 508                req.event_properties:output::string AS output,
 509                rej.event_properties:was_shown::boolean AS was_shown,
 510                rej.event_properties:reason::string AS reason,
 511                req.event_properties:zed_version::string AS zed_version
 512            FROM events req
 513            INNER JOIN events rej
 514                ON req.event_properties:request_id = rej.event_properties:request_id
 515            WHERE req.event_type = ?
 516                AND rej.event_type = ?
 517                AND req.event_properties:version = 'V3'
 518                AND rej.event_properties:was_shown = true
 519                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 520                AND (? IS NULL OR (
 521                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 522                    OR (
 523                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 524                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 525                    )
 526                ))
 527            ORDER BY req.time ASC
 528            LIMIT ?
 529            OFFSET ?
 530        "#};
 531
 532        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
 533        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
 534        let min_minor_str_ref = min_minor_str.as_deref();
 535        let min_patch_str_ref = min_patch_str.as_deref();
 536        let request = json!({
 537            "statement": statement,
 538            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 539            "database": "EVENTS",
 540            "schema": "PUBLIC",
 541            "warehouse": "DBT",
 542            "role": role,
 543            "bindings": {
 544                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 545                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
 546                "3": { "type": "TEXT", "value": after_date },
 547                "4": { "type": "FIXED", "value": min_minor_str_ref },
 548                "5": { "type": "FIXED", "value": min_minor_str_ref },
 549                "6": { "type": "FIXED", "value": min_minor_str_ref },
 550                "7": { "type": "FIXED", "value": min_patch_str_ref },
 551                "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 552                "9": { "type": "FIXED", "value": offset.to_string() }
 553            }
 554        });
 555
 556        let response = run_sql_with_polling(
 557            http_client.clone(),
 558            &base_url,
 559            &token,
 560            &request,
 561            &step_progress,
 562            background_executor.clone(),
 563        )
 564        .await?;
 565
 566        let total_rows = response
 567            .result_set_meta_data
 568            .as_ref()
 569            .and_then(|m| m.num_rows)
 570            .unwrap_or(response.data.len() as i64);
 571
 572        let num_partitions = response
 573            .result_set_meta_data
 574            .as_ref()
 575            .map(|m| m.partition_info.len())
 576            .unwrap_or(1)
 577            .max(1);
 578
 579        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 580        step_progress.set_substatus("parsing");
 581
 582        let column_indices = get_column_indices(
 583            &response.result_set_meta_data,
 584            &[
 585                "request_id",
 586                "device_id",
 587                "time",
 588                "input",
 589                "prompt",
 590                "output",
 591                "was_shown",
 592                "reason",
 593                "zed_version",
 594            ],
 595        );
 596
 597        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
 598
 599        if num_partitions > 1 {
 600            let statement_handle = response
 601                .statement_handle
 602                .as_ref()
 603                .context("response has multiple partitions but no statementHandle")?;
 604
 605            for partition in 1..num_partitions {
 606                step_progress.set_substatus(format!(
 607                    "fetching partition {}/{}",
 608                    partition + 1,
 609                    num_partitions
 610                ));
 611
 612                let partition_response = fetch_partition(
 613                    http_client.clone(),
 614                    &base_url,
 615                    &token,
 616                    statement_handle,
 617                    partition,
 618                )
 619                .await?;
 620
 621                all_examples.extend(rejected_examples_from_response(
 622                    &partition_response,
 623                    &column_indices,
 624                )?);
 625            }
 626        }
 627
 628        step_progress.set_substatus("done");
 629    }
 630
 631    Ok(all_examples)
 632}
 633
 634pub async fn fetch_requested_examples_after(
 635    http_client: Arc<dyn HttpClient>,
 636    after_timestamps: &[String],
 637    max_rows_per_timestamp: usize,
 638    offset: usize,
 639    background_executor: BackgroundExecutor,
 640    min_capture_version: Option<MinCaptureVersion>,
 641) -> Result<Vec<Example>> {
 642    if after_timestamps.is_empty() {
 643        return Ok(Vec::new());
 644    }
 645
 646    let progress = Progress::global();
 647
 648    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 649        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 650    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 651        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 652    )?;
 653    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 654
 655    let mut all_examples = Vec::new();
 656
 657    for after_date in after_timestamps.iter() {
 658        let step_progress_name = format!("requested>{after_date}");
 659        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 660        step_progress.set_substatus("querying");
 661
 662        let statement = indoc! {r#"
 663            SELECT
 664                req.event_properties:request_id::string AS request_id,
 665                req.device_id::string AS device_id,
 666                req.time::string AS time,
 667                req.event_properties:input AS input,
 668                req.event_properties:zed_version::string AS zed_version
 669            FROM events req
 670            WHERE req.event_type = ?
 671                AND req.event_properties:version = 'V3'
 672                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 673                AND (? IS NULL OR (
 674                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 675                    OR (
 676                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 677                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 678                    )
 679                ))
 680            ORDER BY req.time ASC
 681            LIMIT ?
 682            OFFSET ?
 683        "#};
 684
 685        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
 686        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
 687        let min_minor_str_ref = min_minor_str.as_deref();
 688        let min_patch_str_ref = min_patch_str.as_deref();
 689        let request = json!({
 690            "statement": statement,
 691            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 692            "database": "EVENTS",
 693            "schema": "PUBLIC",
 694            "warehouse": "DBT",
 695            "role": role,
 696            "bindings": {
 697                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 698                "2": { "type": "TEXT", "value": after_date },
 699                "3": { "type": "FIXED", "value": min_minor_str_ref },
 700                "4": { "type": "FIXED", "value": min_minor_str_ref },
 701                "5": { "type": "FIXED", "value": min_minor_str_ref },
 702                "6": { "type": "FIXED", "value": min_patch_str_ref },
 703                "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 704                "8": { "type": "FIXED", "value": offset.to_string() }
 705            }
 706        });
 707
 708        let response = run_sql_with_polling(
 709            http_client.clone(),
 710            &base_url,
 711            &token,
 712            &request,
 713            &step_progress,
 714            background_executor.clone(),
 715        )
 716        .await?;
 717
 718        let total_rows = response
 719            .result_set_meta_data
 720            .as_ref()
 721            .and_then(|m| m.num_rows)
 722            .unwrap_or(response.data.len() as i64);
 723
 724        let num_partitions = response
 725            .result_set_meta_data
 726            .as_ref()
 727            .map(|m| m.partition_info.len())
 728            .unwrap_or(1)
 729            .max(1);
 730
 731        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 732        step_progress.set_substatus("parsing");
 733
 734        let column_indices = get_column_indices(
 735            &response.result_set_meta_data,
 736            &["request_id", "device_id", "time", "input", "zed_version"],
 737        );
 738
 739        all_examples.extend(requested_examples_from_response(
 740            &response,
 741            &column_indices,
 742        )?);
 743
 744        if num_partitions > 1 {
 745            let statement_handle = response
 746                .statement_handle
 747                .as_ref()
 748                .context("response has multiple partitions but no statementHandle")?;
 749
 750            for partition in 1..num_partitions {
 751                step_progress.set_substatus(format!(
 752                    "fetching partition {}/{}",
 753                    partition + 1,
 754                    num_partitions
 755                ));
 756
 757                let partition_response = fetch_partition(
 758                    http_client.clone(),
 759                    &base_url,
 760                    &token,
 761                    statement_handle,
 762                    partition,
 763                )
 764                .await?;
 765
 766                all_examples.extend(requested_examples_from_response(
 767                    &partition_response,
 768                    &column_indices,
 769                )?);
 770            }
 771        }
 772
 773        step_progress.set_substatus("done");
 774    }
 775
 776    Ok(all_examples)
 777}
 778
/// Pulls "Edit Prediction Rated" telemetry rows from Snowflake and converts them into
/// [`Example`]s.
///
/// One SQL statement is executed per entry in `inputs`. Each entry is a pair of
/// `(after_date, rating_filter)`: only rows with `rated.time` strictly greater than
/// `after_date` are selected, and when `rating_filter` is `Some`, only rows whose rating
/// matches it. `max_rows_per_timestamp` and `offset` are bound to the statement's
/// `LIMIT` and `OFFSET`, so they apply to every entry independently. When
/// `min_capture_version` is `Some`, rows are additionally restricted to a `zed_version`
/// whose minor/patch components are at least the given version.
///
/// Returns an empty vector without touching the environment or the network when
/// `inputs` is empty.
///
/// # Errors
///
/// Fails when `EP_SNOWFLAKE_API_KEY` or `EP_SNOWFLAKE_BASE_URL` is not set, when running
/// the statement or fetching a partition fails, when a multi-partition response carries
/// no statement handle, or when a response reports a non-success Snowflake code.
pub async fn fetch_rated_examples_after(
    http_client: Arc<dyn HttpClient>,
    inputs: &[(String, Option<EditPredictionRating>)],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    // Nothing to query: skip the environment lookups entirely.
    if inputs.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Credentials and endpoint come from the environment; the role is optional.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    for (after_date, rating_filter) in inputs.iter() {
        // The label only affects the progress step name shown for this query.
        let filter_label = match rating_filter {
            None => "",
            Some(EditPredictionRating::Positive) => ":positive",
            Some(EditPredictionRating::Negative) => ":negative",
        };
        let step_progress_name = format!("rated{filter_label}>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // Value bound into the SQL rating comparison; `None` serializes to JSON null,
        // which makes the `? IS NULL OR ...` clause match every rating.
        let rating_value = rating_filter.as_ref().map(|r| match r {
            EditPredictionRating::Positive => "Positive",
            EditPredictionRating::Negative => "Negative",
        });

        // Placeholders are positional: the order of `?` below must stay in sync with the
        // numbered keys of `bindings`.
        let statement = indoc! {r#"
            SELECT
                rated.event_properties:request_id::string AS request_id,
                rated.event_properties:inputs AS inputs,
                rated.event_properties:output::string AS output,
                rated.event_properties:rating::string AS rating,
                rated.event_properties:feedback::string AS feedback,
                rated.device_id::string AS device_id,
                rated.time::string AS time,
                deploy.event_properties:experiment_name::string AS experiment_name,
                deploy.event_properties:environment::string AS environment,
                rated.event_properties:zed_version::string AS zed_version
            FROM events rated
            LEFT JOIN events req
                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
                AND req.event_type = ?
            LEFT JOIN events deploy
                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
                AND deploy.event_type = ?
            WHERE rated.event_type = ?
                AND (? IS NULL OR rated.event_properties:rating::string = ?)
                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND rated.event_properties:inputs IS NOT NULL
                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
                AND rated.event_properties:output IS NOT NULL
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(rated.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY rated.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Numeric bindings are sent as strings. The minimum minor version is bound three
        // times (null check, `>` and `=` comparisons) and the patch once.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        let bindings = json!({
            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
            "4": { "type": "TEXT", "value": rating_value },
            "5": { "type": "TEXT", "value": rating_value },
            "6": { "type": "TEXT", "value": after_date },
            "7": { "type": "FIXED", "value": min_minor_str_ref },
            "8": { "type": "FIXED", "value": min_minor_str_ref },
            "9": { "type": "FIXED", "value": min_minor_str_ref },
            "10": { "type": "FIXED", "value": min_patch_str_ref },
            "11": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
            "12": { "type": "FIXED", "value": offset.to_string() }
        });

        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": bindings
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer the row count reported in the metadata; fall back to the number of rows
        // carried by this response when it is absent.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Treat missing or empty partition info as a single partition.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Column positions are resolved once from the initial response's metadata and
        // reused for every additional partition.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "inputs",
                "output",
                "rating",
                "feedback",
                "device_id",
                "time",
                "experiment_name",
                "environment",
                "zed_version",
            ],
        );

        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);

        // The rows in the initial response were consumed above; the remaining partitions
        // are fetched by index, starting at 1.
        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rated_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
 962
 963fn rated_examples_from_response<'a>(
 964    response: &'a SnowflakeStatementResponse,
 965    column_indices: &'a std::collections::HashMap<String, usize>,
 966) -> Result<impl Iterator<Item = Example> + 'a> {
 967    if let Some(code) = &response.code {
 968        if code != SNOWFLAKE_SUCCESS_CODE {
 969            anyhow::bail!(
 970                "snowflake sql api returned error code={code} message={}",
 971                response.message.as_deref().unwrap_or("<no message>")
 972            );
 973        }
 974    }
 975
 976    let iter = response
 977        .data
 978        .iter()
 979        .enumerate()
 980        .filter_map(move |(row_index, data_row)| {
 981            let get_string = |name: &str| -> Option<String> {
 982                let index = column_indices.get(name).copied()?;
 983                match data_row.get(index)? {
 984                    JsonValue::String(s) => Some(s.clone()),
 985                    JsonValue::Null => None,
 986                    other => Some(other.to_string()),
 987                }
 988            };
 989
 990            let get_json = |name: &str| -> Option<JsonValue> {
 991                let index = column_indices.get(name).copied()?;
 992                let value = data_row.get(index)?;
 993                if value.is_null() {
 994                    return None;
 995                }
 996                match value {
 997                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 998                    other => Some(other.clone()),
 999                }
1000            };
1001
1002            let request_id = get_string("request_id");
1003            let inputs_json = get_json("inputs");
1004            let inputs: Option<ZetaPromptInput> = match &inputs_json {
1005                Some(v) => match serde_json::from_value(v.clone()) {
1006                    Ok(parsed) => Some(parsed),
1007                    Err(e) => {
1008                        log::warn!(
1009                            "skipping row {row_index}: failed to parse inputs - {e}",
1010                        );
1011                        return None;
1012                    }
1013                },
1014                None => None,
1015            };
1016            let output = get_string("output");
1017            let rating = get_string("rating");
1018            let feedback = get_string("feedback").unwrap_or_default();
1019            let device_id = get_string("device_id");
1020            let time = get_string("time");
1021            let experiment_name = get_string("experiment_name");
1022            let environment = get_string("environment");
1023            let zed_version = get_string("zed_version");
1024
1025            match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
1026                (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
1027                    Some(build_rated_example(
1028                        request_id,
1029                        device_id,
1030                        time,
1031                        inputs,
1032                        output,
1033                        rating,
1034                        feedback,
1035                        experiment_name,
1036                        environment,
1037                        zed_version,
1038                    ))
1039                }
1040                _ => {
1041                    log::warn!(
1042                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
1043                        inputs_json.is_some(),
1044                        output.is_some(),
1045                        rating.is_some(),
1046                        device_id.is_some(),
1047                        time.is_some(),
1048                    );
1049                    None
1050                }
1051            }
1052        });
1053
1054    Ok(iter)
1055}
1056
1057fn build_rated_example(
1058    request_id: Option<String>,
1059    device_id: String,
1060    time: String,
1061    input: ZetaPromptInput,
1062    output: String,
1063    rating: String,
1064    feedback: String,
1065    experiment_name: Option<String>,
1066    environment: Option<String>,
1067    zed_version: Option<String>,
1068) -> Example {
1069    let parsed_rating = if rating == "Positive" {
1070        EditPredictionRating::Positive
1071    } else {
1072        EditPredictionRating::Negative
1073    };
1074    let is_positive = parsed_rating == EditPredictionRating::Positive;
1075    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1076
1077    let mut tags = Vec::with_capacity(3);
1078    tags.push(if is_positive {
1079        "rated:positive".to_string()
1080    } else {
1081        "rated:negative".to_string()
1082    });
1083    if let Some(experiment) = experiment_name {
1084        tags.push(format!("experiment:{experiment}"));
1085    }
1086    if let Some(env) = environment {
1087        tags.push(format!("environment:{env}"));
1088    }
1089
1090    let mut example =
1091        build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1092
1093    example.spec.rating = Some(parsed_rating);
1094
1095    if !feedback.is_empty() {
1096        example
1097            .spec
1098            .human_feedback
1099            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1100    }
1101
1102    if is_positive {
1103        example.spec.expected_patches = vec![output];
1104    } else {
1105        example.spec.rejected_patch = Some(output);
1106    }
1107
1108    example
1109}
1110
1111fn requested_examples_from_response<'a>(
1112    response: &'a SnowflakeStatementResponse,
1113    column_indices: &'a std::collections::HashMap<String, usize>,
1114) -> Result<impl Iterator<Item = Example> + 'a> {
1115    if let Some(code) = &response.code {
1116        if code != SNOWFLAKE_SUCCESS_CODE {
1117            anyhow::bail!(
1118                "snowflake sql api returned error code={code} message={}",
1119                response.message.as_deref().unwrap_or("<no message>")
1120            );
1121        }
1122    }
1123
1124    let iter = response
1125        .data
1126        .iter()
1127        .enumerate()
1128        .filter_map(move |(row_index, data_row)| {
1129            let get_string = |name: &str| -> Option<String> {
1130                let index = column_indices.get(name).copied()?;
1131                match data_row.get(index)? {
1132                    JsonValue::String(s) => Some(s.clone()),
1133                    JsonValue::Null => None,
1134                    other => Some(other.to_string()),
1135                }
1136            };
1137
1138            let get_json = |name: &str| -> Option<JsonValue> {
1139                let index = column_indices.get(name).copied()?;
1140                let value = data_row.get(index)?;
1141                if value.is_null() {
1142                    return None;
1143                }
1144                match value {
1145                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1146                    other => Some(other.clone()),
1147                }
1148            };
1149
1150            let request_id_str = get_string("request_id");
1151            let device_id = get_string("device_id");
1152            let time = get_string("time");
1153            let input_json = get_json("input");
1154            let input: Option<ZetaPromptInput> =
1155                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1156            let zed_version = get_string("zed_version");
1157
1158            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1159                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1160                    Some(build_example_from_snowflake(
1161                        request_id,
1162                        device_id,
1163                        time,
1164                        input,
1165                        vec!["requested".to_string()],
1166                        None,
1167                        zed_version,
1168                    ))
1169                }
1170                _ => {
1171                    log::warn!(
1172                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1173                        request_id_str.is_some(),
1174                        device_id.is_some(),
1175                        time.is_some(),
1176                        input_json.is_some(),
1177                    );
1178                    None
1179                }
1180            }
1181        });
1182
1183    Ok(iter)
1184}
1185
1186fn rejected_examples_from_response<'a>(
1187    response: &'a SnowflakeStatementResponse,
1188    column_indices: &'a std::collections::HashMap<String, usize>,
1189) -> Result<impl Iterator<Item = Example> + 'a> {
1190    if let Some(code) = &response.code {
1191        if code != SNOWFLAKE_SUCCESS_CODE {
1192            anyhow::bail!(
1193                "snowflake sql api returned error code={code} message={}",
1194                response.message.as_deref().unwrap_or("<no message>")
1195            );
1196        }
1197    }
1198
1199    let iter = response
1200        .data
1201        .iter()
1202        .enumerate()
1203        .filter_map(move |(row_index, data_row)| {
1204            let get_string = |name: &str| -> Option<String> {
1205                let index = column_indices.get(name).copied()?;
1206                match data_row.get(index)? {
1207                    JsonValue::String(s) => Some(s.clone()),
1208                    JsonValue::Null => None,
1209                    other => Some(other.to_string()),
1210                }
1211            };
1212
1213            let get_json = |name: &str| -> Option<JsonValue> {
1214                let index = column_indices.get(name).copied()?;
1215                let value = data_row.get(index)?;
1216                if value.is_null() {
1217                    return None;
1218                }
1219                match value {
1220                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1221                    other => Some(other.clone()),
1222                }
1223            };
1224
1225            let get_bool = |name: &str| -> Option<bool> {
1226                let index = column_indices.get(name).copied()?;
1227                match data_row.get(index)? {
1228                    JsonValue::Bool(b) => Some(*b),
1229                    JsonValue::String(s) => s.parse().ok(),
1230                    _ => None,
1231                }
1232            };
1233
1234            let request_id_str = get_string("request_id");
1235            let device_id = get_string("device_id");
1236            let time = get_string("time");
1237            let input_json = get_json("input");
1238            let input: Option<ZetaPromptInput> =
1239                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1240            let output = get_string("output");
1241            let was_shown = get_bool("was_shown");
1242            let reason = get_string("reason");
1243            let zed_version = get_string("zed_version");
1244
1245            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1246                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1247                    Some(build_rejected_example(
1248                        request_id,
1249                        device_id,
1250                        time,
1251                        input,
1252                        output,
1253                        was_shown,
1254                        reason,
1255                        zed_version,
1256                    ))
1257                }
1258                _ => {
1259                    log::warn!(
1260                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1261                        request_id_str.is_some(),
1262                        device_id.is_some(),
1263                        time.is_some(),
1264                        input_json.is_some(),
1265                        output.is_some(),
1266                        was_shown.is_some(),
1267                        reason.is_some()
1268                    );
1269                    None
1270                }
1271            }
1272        });
1273
1274    Ok(iter)
1275}
1276
1277fn build_rejected_example(
1278    request_id: String,
1279    device_id: String,
1280    time: String,
1281    input: ZetaPromptInput,
1282    output: String,
1283    was_shown: bool,
1284    reason: String,
1285    zed_version: Option<String>,
1286) -> Example {
1287    let rejected_patch = build_output_patch(
1288        &input.cursor_path,
1289        input.cursor_excerpt.as_ref(),
1290        &input.editable_range_in_excerpt,
1291        &output,
1292    );
1293    let mut example = build_example_from_snowflake(
1294        request_id,
1295        device_id,
1296        time,
1297        input,
1298        vec![format!("rejection:{}", reason.to_lowercase())],
1299        Some(RejectionInfo { reason, was_shown }),
1300        zed_version,
1301    );
1302    example.spec.rejected_patch = Some(rejected_patch);
1303    example
1304}
1305
/// Rejection details for an example built from a rejected prediction.
///
/// Both fields are copied into the example's `TelemetrySource` by
/// `build_example_from_snowflake`.
struct RejectionInfo {
    /// Rejection reason as reported by telemetry; also used, lowercased, in the
    /// `rejection:<reason>` tag.
    reason: String,
    /// The `was_shown` telemetry flag (presumably whether the prediction was displayed
    /// to the user before being rejected; confirm against the event definition).
    was_shown: bool,
}
1310
1311fn build_example_from_snowflake(
1312    request_id: String,
1313    device_id: String,
1314    time: String,
1315    input: ZetaPromptInput,
1316    tags: Vec<String>,
1317    rejection: Option<RejectionInfo>,
1318    zed_version: Option<String>,
1319) -> Example {
1320    let events: Vec<CapturedEvent> = input
1321        .events
1322        .iter()
1323        .map(|event| match event.as_ref() {
1324            zeta_prompt::Event::BufferChange {
1325                path,
1326                old_path,
1327                diff,
1328                predicted,
1329                in_open_source_repo,
1330            } => CapturedEvent {
1331                path: path.clone(),
1332                old_path: old_path.clone(),
1333                diff: diff.clone(),
1334                predicted: *predicted,
1335                in_open_source_repo: *in_open_source_repo,
1336            },
1337        })
1338        .collect();
1339
1340    let related_files: Vec<CapturedRelatedFile> = input
1341        .related_files
1342        .iter()
1343        .map(|rf| CapturedRelatedFile {
1344            path: rf.path.clone(),
1345            max_row: rf.max_row,
1346            excerpts: rf
1347                .excerpts
1348                .iter()
1349                .map(|e| CapturedRelatedExcerpt {
1350                    row_range: e.row_range.clone(),
1351                    text: e.text.to_string(),
1352                })
1353                .collect(),
1354        })
1355        .collect();
1356
1357    let cursor_excerpt = input.cursor_excerpt.as_ref();
1358    let cursor_offset = input.cursor_offset_in_excerpt;
1359
1360    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1361
1362    let mut edit_history = String::new();
1363    for event in &input.events {
1364        zeta_prompt::write_event(&mut edit_history, event);
1365        edit_history.push('\n');
1366    }
1367
1368    let (rejection_reason, was_shown) = match &rejection {
1369        Some(r) => (r.reason.clone(), r.was_shown),
1370        None => (String::new(), false),
1371    };
1372
1373    let spec = ExampleSpec {
1374        name: request_id.clone(),
1375        repository_url: String::new(),
1376        revision: String::new(),
1377        tags,
1378        reasoning: None,
1379        uncommitted_diff: String::new(),
1380        cursor_path: input.cursor_path.clone(),
1381        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1382        edit_history,
1383        expected_patches: Vec::new(),
1384        rejected_patch: None,
1385        captured_prompt_input: Some(CapturedPromptInput {
1386            cursor_file_content: cursor_excerpt.to_string(),
1387            cursor_offset,
1388            cursor_row,
1389            cursor_column,
1390            excerpt_start_row: None,
1391            events,
1392            related_files,
1393            in_open_source_repo: input.in_open_source_repo,
1394            zed_version,
1395        }),
1396        telemetry: Some(TelemetrySource {
1397            request_id,
1398            device_id,
1399            time,
1400            rejection_reason,
1401            was_shown,
1402        }),
1403        human_feedback: Vec::new(),
1404        rating: None,
1405    };
1406
1407    Example {
1408        spec,
1409        prompt_inputs: None,
1410        prompt: None,
1411        predictions: Vec::new(),
1412        score: Vec::new(),
1413        qa: Vec::new(),
1414        state: None,
1415    }
1416}
1417
/// Converts a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// The row is the number of newlines that start before `offset`; the column is the
/// number of bytes between the start of that row and `offset`. An `offset` past the end
/// of `text` is not clamped, so the column keeps growing with it.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // Scanning bytes is enough: `\n` is a single byte in UTF-8 and never occurs inside
    // a multi-byte sequence.
    let scanned = &text.as_bytes()[..offset.min(text.len())];
    let (row, line_start) = scanned
        .iter()
        .enumerate()
        .filter(|(_, byte)| **byte == b'\n')
        .fold((0u32, 0usize), |(rows, _), (index, _)| (rows + 1, index + 1));
    (row, (offset - line_start) as u32)
}
1433
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. It is clamped to the length of `excerpt`, and an
/// offset that falls inside a multi-byte character is moved back to the start of that
/// character, so the function never panics on an offset that does not line up with the
/// text.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Slicing a `str` off a char boundary panics; index 0 is always a boundary, so this
    // loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
1439
1440fn build_output_patch(
1441    cursor_path: &std::path::Path,
1442    cursor_excerpt: &str,
1443    editable_range: &std::ops::Range<usize>,
1444    model_output: &str,
1445) -> String {
1446    let old_text = &cursor_excerpt[editable_range.clone()];
1447
1448    let editable_start_row = cursor_excerpt[..editable_range.start]
1449        .chars()
1450        .filter(|&c| c == '\n')
1451        .count() as u32;
1452
1453    let diff_body = language::unified_diff_with_offsets(
1454        old_text,
1455        model_output,
1456        editable_start_row,
1457        editable_start_row,
1458    );
1459
1460    let mut patch = String::new();
1461    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1462    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1463    patch.push_str(&diff_body);
1464    patch
1465}
1466
1467pub(crate) fn get_column_indices(
1468    meta: &Option<SnowflakeResultSetMetaData>,
1469    names: &[&str],
1470) -> std::collections::HashMap<String, usize> {
1471    let mut indices = std::collections::HashMap::new();
1472    if let Some(meta) = meta {
1473        for (index, col) in meta.row_type.iter().enumerate() {
1474            for &name in names {
1475                if col.name.eq_ignore_ascii_case(name) {
1476                    indices.insert(name.to_string(), index);
1477                }
1478            }
1479        }
1480    }
1481    indices
1482}