pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::io::Read;
   9use std::sync::Arc;
  10use std::time::Duration;
  11use telemetry_events::EditPredictionRating;
  12
  13use zeta_prompt::ZetaPromptInput;
  14
  15use crate::example::Example;
  16use crate::progress::{InfoStyle, Progress, Step};
  17use crate::sync_deployments::EDIT_PREDICTION_DEPLOYMENT_EVENT;
  18use edit_prediction::example_spec::{
  19    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
  20    TelemetrySource,
  21};
  22use std::fmt::Write as _;
  23
  24pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
  25pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
  26const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
  27const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
  28const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
  29const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
  30
  31/// Minimum Zed version for filtering captured examples.
  32/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
  33/// where `zed_version >= 0.224.1`.
  34#[derive(Clone, Copy, Debug)]
  35pub struct MinCaptureVersion {
  36    pub minor: u32,
  37    pub patch: u32,
  38}
  39
  40const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
  41pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
  42pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
  43
  44/// Parse an input token of the form `captured-after:{timestamp}`.
  45pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  46    input.strip_prefix("captured-after:")
  47}
  48
  49/// Parse an input token of the form `rejected-after:{timestamp}`.
  50pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  51    input.strip_prefix("rejected-after:")
  52}
  53
  54/// Parse an input token of the form `requested-after:{timestamp}`.
  55pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  56    input.strip_prefix("requested-after:")
  57}
  58
  59/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  60/// or `rated-negative-after:{timestamp}`.
  61/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  62pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  63    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  64        Some((timestamp, Some(EditPredictionRating::Positive)))
  65    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  66        Some((timestamp, Some(EditPredictionRating::Negative)))
  67    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  68        Some((timestamp, None))
  69    } else {
  70        None
  71    }
  72}
  73
  74pub async fn fetch_captured_examples_after(
  75    http_client: Arc<dyn HttpClient>,
  76    after_timestamps: &[String],
  77    max_rows_per_timestamp: usize,
  78    offset: usize,
  79    background_executor: BackgroundExecutor,
  80    _min_capture_version: Option<MinCaptureVersion>,
  81) -> Result<Vec<Example>> {
  82    if after_timestamps.is_empty() {
  83        return Ok(Vec::new());
  84    }
  85
  86    let progress = Progress::global();
  87
  88    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
  89        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
  90    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
  91        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
  92    )?;
  93    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
  94
  95    let mut all_examples = Vec::new();
  96
  97    for after_date in after_timestamps.iter() {
  98        let step_progress_name = format!(">{after_date}");
  99        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 100        step_progress.set_substatus("querying");
 101
 102        let statement = indoc! {r#"
 103            SELECT
 104                event_properties:example AS example
 105            FROM events
 106            WHERE event_type = ?
 107                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 108                AND event_properties:can_collect_data = true
 109            ORDER BY time ASC
 110            LIMIT ?
 111            OFFSET ?
 112        "#};
 113
 114        let request = json!({
 115            "statement": statement,
 116            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 117            "database": "EVENTS",
 118            "schema": "PUBLIC",
 119            "warehouse": "DBT",
 120            "role": role,
 121            "bindings": {
 122                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 123                "2": { "type": "TEXT", "value": after_date },
 124                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 125                "4": { "type": "FIXED", "value": offset.to_string() }
 126            }
 127        });
 128
 129        let response = run_sql_with_polling(
 130            http_client.clone(),
 131            &base_url,
 132            &token,
 133            &request,
 134            &step_progress,
 135            background_executor.clone(),
 136        )
 137        .await?;
 138
 139        let total_rows = response
 140            .result_set_meta_data
 141            .as_ref()
 142            .and_then(|m| m.num_rows)
 143            .unwrap_or(response.data.len() as i64);
 144
 145        let num_partitions = response
 146            .result_set_meta_data
 147            .as_ref()
 148            .map(|m| m.partition_info.len())
 149            .unwrap_or(1)
 150            .max(1);
 151
 152        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 153        step_progress.set_substatus("parsing");
 154
 155        let example_index = response
 156            .result_set_meta_data
 157            .as_ref()
 158            .and_then(|m| {
 159                m.row_type.iter().enumerate().find_map(|(index, col)| {
 160                    if col.name.eq_ignore_ascii_case("example") {
 161                        Some(index)
 162                    } else {
 163                        None
 164                    }
 165                })
 166            })
 167            .unwrap_or(0);
 168
 169        all_examples.extend(examples_from_response(&response, example_index)?);
 170
 171        if num_partitions > 1 {
 172            let statement_handle = response
 173                .statement_handle
 174                .as_ref()
 175                .context("response has multiple partitions but no statementHandle")?;
 176
 177            for partition in 1..num_partitions {
 178                step_progress.set_substatus(format!(
 179                    "fetching partition {}/{}",
 180                    partition + 1,
 181                    num_partitions
 182                ));
 183
 184                let partition_response = fetch_partition(
 185                    http_client.clone(),
 186                    &base_url,
 187                    &token,
 188                    statement_handle,
 189                    partition,
 190                )
 191                .await?;
 192
 193                all_examples.extend(examples_from_response(&partition_response, example_index)?);
 194            }
 195        }
 196
 197        step_progress.set_substatus("done");
 198    }
 199
 200    Ok(all_examples)
 201}
 202
 203#[derive(Debug, Clone, Deserialize)]
 204#[serde(rename_all = "camelCase")]
 205pub(crate) struct SnowflakeStatementResponse {
 206    #[serde(default)]
 207    pub(crate) data: Vec<Vec<JsonValue>>,
 208    #[serde(default)]
 209    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
 210    #[serde(default)]
 211    pub(crate) code: Option<String>,
 212    #[serde(default)]
 213    pub(crate) message: Option<String>,
 214    #[serde(default)]
 215    pub(crate) statement_handle: Option<String>,
 216}
 217
 218#[derive(Debug, Clone, Deserialize)]
 219#[serde(rename_all = "camelCase")]
 220pub(crate) struct SnowflakeResultSetMetaData {
 221    #[serde(default, rename = "rowType")]
 222    row_type: Vec<SnowflakeColumnMeta>,
 223    #[serde(default)]
 224    num_rows: Option<i64>,
 225    #[serde(default)]
 226    partition_info: Vec<SnowflakePartitionInfo>,
 227}
 228
 229#[derive(Debug, Clone, Deserialize)]
 230#[serde(rename_all = "camelCase")]
 231struct SnowflakePartitionInfo {}
 232
 233#[derive(Debug, Clone, Deserialize)]
 234struct SnowflakeColumnMeta {
 235    #[serde(default)]
 236    name: String,
 237}
 238
 239fn examples_from_response(
 240    response: &SnowflakeStatementResponse,
 241    example_index: usize,
 242) -> Result<impl Iterator<Item = Example> + '_> {
 243    if let Some(code) = &response.code {
 244        if code != SNOWFLAKE_SUCCESS_CODE {
 245            anyhow::bail!(
 246                "snowflake sql api returned error code={code} message={}",
 247                response.message.as_deref().unwrap_or("<no message>")
 248            );
 249        }
 250    }
 251
 252    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
 253        let Some(example_value) = data_row.get(example_index) else {
 254            return None;
 255        };
 256        if example_value.is_null() {
 257            return None;
 258        }
 259
 260        let parse_result = match example_value {
 261            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
 262            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
 263        };
 264
 265        match parse_result {
 266            Ok(spec) => Some(Example {
 267                spec,
 268                prompt_inputs: None,
 269                prompt: None,
 270                predictions: Vec::new(),
 271                score: Vec::new(),
 272                qa: Vec::new(),
 273                state: None,
 274            }),
 275            Err(error) => {
 276                let raw_json = serde_json::to_string_pretty(example_value)
 277                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
 278                log::error!(
 279                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
 280                );
 281                None
 282            }
 283        }
 284    });
 285
 286    Ok(iter)
 287}
 288
 289async fn run_sql_with_polling(
 290    http_client: Arc<dyn HttpClient>,
 291    base_url: &str,
 292    token: &str,
 293    request: &serde_json::Value,
 294    step_progress: &crate::progress::StepProgress,
 295    background_executor: BackgroundExecutor,
 296) -> Result<SnowflakeStatementResponse> {
 297    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 298
 299    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 300        let statement_handle = response
 301            .statement_handle
 302            .as_ref()
 303            .context("async query response missing statementHandle")?
 304            .clone();
 305
 306        for attempt in 1..=MAX_POLL_ATTEMPTS {
 307            step_progress.set_substatus(format!("polling ({attempt})"));
 308
 309            background_executor.timer(POLL_INTERVAL).await;
 310
 311            response =
 312                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 313
 314            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 315                break;
 316            }
 317        }
 318
 319        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 320            anyhow::bail!(
 321                "query still running after {} poll attempts ({} seconds)",
 322                MAX_POLL_ATTEMPTS,
 323                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 324            );
 325        }
 326    }
 327
 328    Ok(response)
 329}
 330
 331pub(crate) async fn fetch_partition(
 332    http_client: Arc<dyn HttpClient>,
 333    base_url: &str,
 334    token: &str,
 335    statement_handle: &str,
 336    partition: usize,
 337) -> Result<SnowflakeStatementResponse> {
 338    let url = format!(
 339        "{}/api/v2/statements/{}?partition={}",
 340        base_url.trim_end_matches('/'),
 341        statement_handle,
 342        partition
 343    );
 344
 345    let http_request = Request::builder()
 346        .method(Method::GET)
 347        .uri(url.as_str())
 348        .header("Authorization", format!("Bearer {token}"))
 349        .header(
 350            "X-Snowflake-Authorization-Token-Type",
 351            "PROGRAMMATIC_ACCESS_TOKEN",
 352        )
 353        .header("Accept", "application/json")
 354        .header("Accept-Encoding", "gzip")
 355        .header("User-Agent", "edit_prediction_cli")
 356        .body(AsyncBody::empty())?;
 357
 358    let response = http_client
 359        .send(http_request)
 360        .await
 361        .context("failed to send partition request to Snowflake SQL API")?;
 362
 363    let status = response.status();
 364    let content_encoding = response
 365        .headers()
 366        .get("content-encoding")
 367        .and_then(|v| v.to_str().ok())
 368        .map(|s| s.to_lowercase());
 369
 370    let body_bytes = {
 371        use futures::AsyncReadExt as _;
 372
 373        let mut body = response.into_body();
 374        let mut bytes = Vec::new();
 375        body.read_to_end(&mut bytes)
 376            .await
 377            .context("failed to read Snowflake SQL API partition response body")?;
 378        bytes
 379    };
 380
 381    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 382        let mut decoder = GzDecoder::new(&body_bytes[..]);
 383        let mut decompressed = Vec::new();
 384        decoder
 385            .read_to_end(&mut decompressed)
 386            .context("failed to decompress gzip response")?;
 387        decompressed
 388    } else {
 389        body_bytes
 390    };
 391
 392    if !status.is_success() && status.as_u16() != 202 {
 393        let body_text = String::from_utf8_lossy(&body_bytes);
 394        anyhow::bail!(
 395            "snowflake sql api partition request http {}: {}",
 396            status.as_u16(),
 397            body_text
 398        );
 399    }
 400
 401    if body_bytes.is_empty() {
 402        anyhow::bail!(
 403            "snowflake sql api partition {} returned empty response body (http {})",
 404            partition,
 405            status.as_u16()
 406        );
 407    }
 408
 409    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 410        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 411        format!(
 412            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 413            partition,
 414            status.as_u16(),
 415            body_preview
 416        )
 417    })
 418}
 419
 420pub(crate) async fn run_sql(
 421    http_client: Arc<dyn HttpClient>,
 422    base_url: &str,
 423    token: &str,
 424    request: &serde_json::Value,
 425) -> Result<SnowflakeStatementResponse> {
 426    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 427
 428    let request_body =
 429        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 430
 431    let http_request = Request::builder()
 432        .method(Method::POST)
 433        .uri(url.as_str())
 434        .header("Authorization", format!("Bearer {token}"))
 435        .header(
 436            "X-Snowflake-Authorization-Token-Type",
 437            "PROGRAMMATIC_ACCESS_TOKEN",
 438        )
 439        .header("Content-Type", "application/json")
 440        .header("Accept", "application/json")
 441        .header("User-Agent", "edit_prediction_cli")
 442        .body(AsyncBody::from(request_body.clone()))?;
 443
 444    let response = http_client
 445        .send(http_request)
 446        .await
 447        .context("failed to send request to Snowflake SQL API")?;
 448
 449    let status = response.status();
 450    let body_bytes = {
 451        use futures::AsyncReadExt as _;
 452
 453        let mut body = response.into_body();
 454        let mut bytes = Vec::new();
 455        body.read_to_end(&mut bytes)
 456            .await
 457            .context("failed to read Snowflake SQL API response body")?;
 458        bytes
 459    };
 460
 461    if !status.is_success() && status.as_u16() != 202 {
 462        let body_text = String::from_utf8_lossy(&body_bytes);
 463        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 464    }
 465
 466    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 467        .context("failed to parse Snowflake SQL API response JSON")
 468}
 469
 470pub async fn fetch_rejected_examples_after(
 471    http_client: Arc<dyn HttpClient>,
 472    after_timestamps: &[String],
 473    max_rows_per_timestamp: usize,
 474    offset: usize,
 475    background_executor: BackgroundExecutor,
 476    min_capture_version: Option<MinCaptureVersion>,
 477) -> Result<Vec<Example>> {
 478    if after_timestamps.is_empty() {
 479        return Ok(Vec::new());
 480    }
 481
 482    let progress = Progress::global();
 483
 484    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 485        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 486    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 487        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 488    )?;
 489    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 490
 491    let mut all_examples = Vec::new();
 492
 493    for after_date in after_timestamps.iter() {
 494        let step_progress_name = format!("rejected>{after_date}");
 495        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 496        step_progress.set_substatus("querying");
 497
 498        // Join rejected events with their corresponding request events to get the full context.
 499        // We filter for V3 sampling data which contains the structured input we need.
 500        // We also filter for predictions that were actually shown to the user (was_shown = true)
 501        // to focus on explicit user rejections rather than implicit cancellations.
 502        let statement = indoc! {r#"
 503            SELECT
 504                req.event_properties:request_id::string AS request_id,
 505                req.device_id::string AS device_id,
 506                req.time::string AS time,
 507                req.event_properties:input AS input,
 508                req.event_properties:prompt::string AS prompt,
 509                req.event_properties:output::string AS output,
 510                rej.event_properties:was_shown::boolean AS was_shown,
 511                rej.event_properties:reason::string AS reason,
 512                req.event_properties:zed_version::string AS zed_version
 513            FROM events req
 514            INNER JOIN events rej
 515                ON req.event_properties:request_id = rej.event_properties:request_id
 516            WHERE req.event_type = ?
 517                AND rej.event_type = ?
 518                AND req.event_properties:version = 'V3'
 519                AND rej.event_properties:was_shown = true
 520                AND req.event_properties:input:can_collect_data = true
 521                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 522                AND (? IS NULL OR (
 523                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 524                    OR (
 525                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 526                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 527                    )
 528                ))
 529            ORDER BY req.time ASC
 530            LIMIT ?
 531            OFFSET ?
 532        "#};
 533
 534        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
 535        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
 536        let min_minor_str_ref = min_minor_str.as_deref();
 537        let min_patch_str_ref = min_patch_str.as_deref();
 538        let request = json!({
 539            "statement": statement,
 540            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 541            "database": "EVENTS",
 542            "schema": "PUBLIC",
 543            "warehouse": "DBT",
 544            "role": role,
 545            "bindings": {
 546                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 547                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
 548                "3": { "type": "TEXT", "value": after_date },
 549                "4": { "type": "FIXED", "value": min_minor_str_ref },
 550                "5": { "type": "FIXED", "value": min_minor_str_ref },
 551                "6": { "type": "FIXED", "value": min_minor_str_ref },
 552                "7": { "type": "FIXED", "value": min_patch_str_ref },
 553                "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 554                "9": { "type": "FIXED", "value": offset.to_string() }
 555            }
 556        });
 557
 558        let response = run_sql_with_polling(
 559            http_client.clone(),
 560            &base_url,
 561            &token,
 562            &request,
 563            &step_progress,
 564            background_executor.clone(),
 565        )
 566        .await?;
 567
 568        let total_rows = response
 569            .result_set_meta_data
 570            .as_ref()
 571            .and_then(|m| m.num_rows)
 572            .unwrap_or(response.data.len() as i64);
 573
 574        let num_partitions = response
 575            .result_set_meta_data
 576            .as_ref()
 577            .map(|m| m.partition_info.len())
 578            .unwrap_or(1)
 579            .max(1);
 580
 581        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 582        step_progress.set_substatus("parsing");
 583
 584        let column_indices = get_column_indices(
 585            &response.result_set_meta_data,
 586            &[
 587                "request_id",
 588                "device_id",
 589                "time",
 590                "input",
 591                "prompt",
 592                "output",
 593                "was_shown",
 594                "reason",
 595                "zed_version",
 596            ],
 597        );
 598
 599        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
 600
 601        if num_partitions > 1 {
 602            let statement_handle = response
 603                .statement_handle
 604                .as_ref()
 605                .context("response has multiple partitions but no statementHandle")?;
 606
 607            for partition in 1..num_partitions {
 608                step_progress.set_substatus(format!(
 609                    "fetching partition {}/{}",
 610                    partition + 1,
 611                    num_partitions
 612                ));
 613
 614                let partition_response = fetch_partition(
 615                    http_client.clone(),
 616                    &base_url,
 617                    &token,
 618                    statement_handle,
 619                    partition,
 620                )
 621                .await?;
 622
 623                all_examples.extend(rejected_examples_from_response(
 624                    &partition_response,
 625                    &column_indices,
 626                )?);
 627            }
 628        }
 629
 630        step_progress.set_substatus("done");
 631    }
 632
 633    Ok(all_examples)
 634}
 635
 636pub async fn fetch_requested_examples_after(
 637    http_client: Arc<dyn HttpClient>,
 638    after_timestamps: &[String],
 639    max_rows_per_timestamp: usize,
 640    offset: usize,
 641    background_executor: BackgroundExecutor,
 642    min_capture_version: Option<MinCaptureVersion>,
 643) -> Result<Vec<Example>> {
 644    if after_timestamps.is_empty() {
 645        return Ok(Vec::new());
 646    }
 647
 648    let progress = Progress::global();
 649
 650    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 651        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 652    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 653        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 654    )?;
 655    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 656
 657    let mut all_examples = Vec::new();
 658
 659    for after_date in after_timestamps.iter() {
 660        let step_progress_name = format!("requested>{after_date}");
 661        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 662        step_progress.set_substatus("querying");
 663
 664        let statement = indoc! {r#"
 665            SELECT
 666                req.event_properties:request_id::string AS request_id,
 667                req.device_id::string AS device_id,
 668                req.time::string AS time,
 669                req.event_properties:input AS input,
 670                req.event_properties:zed_version::string AS zed_version
 671            FROM events req
 672            WHERE req.event_type = ?
 673                AND req.event_properties:version = 'V3'
 674                AND req.event_properties:input:can_collect_data = true
 675                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 676                AND (? IS NULL OR (
 677                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 678                    OR (
 679                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 680                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 681                    )
 682                ))
 683            ORDER BY req.time ASC
 684            LIMIT ?
 685            OFFSET ?
 686        "#};
 687
 688        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
 689        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
 690        let min_minor_str_ref = min_minor_str.as_deref();
 691        let min_patch_str_ref = min_patch_str.as_deref();
 692        let request = json!({
 693            "statement": statement,
 694            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 695            "database": "EVENTS",
 696            "schema": "PUBLIC",
 697            "warehouse": "DBT",
 698            "role": role,
 699            "bindings": {
 700                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 701                "2": { "type": "TEXT", "value": after_date },
 702                "3": { "type": "FIXED", "value": min_minor_str_ref },
 703                "4": { "type": "FIXED", "value": min_minor_str_ref },
 704                "5": { "type": "FIXED", "value": min_minor_str_ref },
 705                "6": { "type": "FIXED", "value": min_patch_str_ref },
 706                "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 707                "8": { "type": "FIXED", "value": offset.to_string() }
 708            }
 709        });
 710
 711        let response = run_sql_with_polling(
 712            http_client.clone(),
 713            &base_url,
 714            &token,
 715            &request,
 716            &step_progress,
 717            background_executor.clone(),
 718        )
 719        .await?;
 720
 721        let total_rows = response
 722            .result_set_meta_data
 723            .as_ref()
 724            .and_then(|m| m.num_rows)
 725            .unwrap_or(response.data.len() as i64);
 726
 727        let num_partitions = response
 728            .result_set_meta_data
 729            .as_ref()
 730            .map(|m| m.partition_info.len())
 731            .unwrap_or(1)
 732            .max(1);
 733
 734        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 735        step_progress.set_substatus("parsing");
 736
 737        let column_indices = get_column_indices(
 738            &response.result_set_meta_data,
 739            &["request_id", "device_id", "time", "input", "zed_version"],
 740        );
 741
 742        all_examples.extend(requested_examples_from_response(
 743            &response,
 744            &column_indices,
 745        )?);
 746
 747        if num_partitions > 1 {
 748            let statement_handle = response
 749                .statement_handle
 750                .as_ref()
 751                .context("response has multiple partitions but no statementHandle")?;
 752
 753            for partition in 1..num_partitions {
 754                step_progress.set_substatus(format!(
 755                    "fetching partition {}/{}",
 756                    partition + 1,
 757                    num_partitions
 758                ));
 759
 760                let partition_response = fetch_partition(
 761                    http_client.clone(),
 762                    &base_url,
 763                    &token,
 764                    statement_handle,
 765                    partition,
 766                )
 767                .await?;
 768
 769                all_examples.extend(requested_examples_from_response(
 770                    &partition_response,
 771                    &column_indices,
 772                )?);
 773            }
 774        }
 775
 776        step_progress.set_substatus("done");
 777    }
 778
 779    Ok(all_examples)
 780}
 781
 782pub async fn fetch_rated_examples_after(
 783    http_client: Arc<dyn HttpClient>,
 784    inputs: &[(String, Option<EditPredictionRating>)],
 785    max_rows_per_timestamp: usize,
 786    offset: usize,
 787    background_executor: BackgroundExecutor,
 788    min_capture_version: Option<MinCaptureVersion>,
 789) -> Result<Vec<Example>> {
 790    if inputs.is_empty() {
 791        return Ok(Vec::new());
 792    }
 793
 794    let progress = Progress::global();
 795
 796    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 797        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 798    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 799        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 800    )?;
 801    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 802
 803    let mut all_examples = Vec::new();
 804
 805    for (after_date, rating_filter) in inputs.iter() {
 806        let filter_label = match rating_filter {
 807            None => "",
 808            Some(EditPredictionRating::Positive) => ":positive",
 809            Some(EditPredictionRating::Negative) => ":negative",
 810        };
 811        let step_progress_name = format!("rated{filter_label}>{after_date}");
 812        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 813        step_progress.set_substatus("querying");
 814
 815        let rating_value = rating_filter.as_ref().map(|r| match r {
 816            EditPredictionRating::Positive => "Positive",
 817            EditPredictionRating::Negative => "Negative",
 818        });
 819
 820        let statement = indoc! {r#"
 821            SELECT
 822                rated.event_properties:request_id::string AS request_id,
 823                rated.event_properties:inputs AS inputs,
 824                rated.event_properties:output::string AS output,
 825                rated.event_properties:rating::string AS rating,
 826                rated.event_properties:feedback::string AS feedback,
 827                rated.device_id::string AS device_id,
 828                rated.time::string AS time,
 829                deploy.event_properties:experiment_name::string AS experiment_name,
 830                deploy.event_properties:environment::string AS environment,
 831                rated.event_properties:zed_version::string AS zed_version
 832            FROM events rated
 833            LEFT JOIN events req
 834                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
 835                AND req.event_type = ?
 836            LEFT JOIN events deploy
 837                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
 838                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
 839                AND deploy.event_type = ?
 840            WHERE rated.event_type = ?
 841                AND (? IS NULL OR rated.event_properties:rating::string = ?)
 842                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
 843                AND rated.event_properties:inputs IS NOT NULL
 844                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
 845                AND rated.event_properties:output IS NOT NULL
 846                AND rated.event_properties:can_collect_data = true
 847                AND (? IS NULL OR (
 848                    TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
 849                    OR (
 850                        TRY_CAST(SPLIT_PART(rated.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
 851                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(rated.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
 852                    )
 853                ))
 854            ORDER BY rated.time ASC
 855            LIMIT ?
 856            OFFSET ?
 857        "#};
 858
 859        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
 860        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
 861        let min_minor_str_ref = min_minor_str.as_deref();
 862        let min_patch_str_ref = min_patch_str.as_deref();
 863        let bindings = json!({
 864            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 865            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
 866            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
 867            "4": { "type": "TEXT", "value": rating_value },
 868            "5": { "type": "TEXT", "value": rating_value },
 869            "6": { "type": "TEXT", "value": after_date },
 870            "7": { "type": "FIXED", "value": min_minor_str_ref },
 871            "8": { "type": "FIXED", "value": min_minor_str_ref },
 872            "9": { "type": "FIXED", "value": min_minor_str_ref },
 873            "10": { "type": "FIXED", "value": min_patch_str_ref },
 874            "11": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 875            "12": { "type": "FIXED", "value": offset.to_string() }
 876        });
 877
 878        let request = json!({
 879            "statement": statement,
 880            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 881            "database": "EVENTS",
 882            "schema": "PUBLIC",
 883            "warehouse": "DBT",
 884            "role": role,
 885            "bindings": bindings
 886        });
 887
 888        let response = run_sql_with_polling(
 889            http_client.clone(),
 890            &base_url,
 891            &token,
 892            &request,
 893            &step_progress,
 894            background_executor.clone(),
 895        )
 896        .await?;
 897
 898        let total_rows = response
 899            .result_set_meta_data
 900            .as_ref()
 901            .and_then(|m| m.num_rows)
 902            .unwrap_or(response.data.len() as i64);
 903
 904        let num_partitions = response
 905            .result_set_meta_data
 906            .as_ref()
 907            .map(|m| m.partition_info.len())
 908            .unwrap_or(1)
 909            .max(1);
 910
 911        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 912        step_progress.set_substatus("parsing");
 913
 914        let column_indices = get_column_indices(
 915            &response.result_set_meta_data,
 916            &[
 917                "request_id",
 918                "inputs",
 919                "output",
 920                "rating",
 921                "feedback",
 922                "device_id",
 923                "time",
 924                "experiment_name",
 925                "environment",
 926                "zed_version",
 927            ],
 928        );
 929
 930        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
 931
 932        if num_partitions > 1 {
 933            let statement_handle = response
 934                .statement_handle
 935                .as_ref()
 936                .context("response has multiple partitions but no statementHandle")?;
 937
 938            for partition in 1..num_partitions {
 939                step_progress.set_substatus(format!(
 940                    "fetching partition {}/{}",
 941                    partition + 1,
 942                    num_partitions
 943                ));
 944
 945                let partition_response = fetch_partition(
 946                    http_client.clone(),
 947                    &base_url,
 948                    &token,
 949                    statement_handle,
 950                    partition,
 951                )
 952                .await?;
 953
 954                all_examples.extend(rated_examples_from_response(
 955                    &partition_response,
 956                    &column_indices,
 957                )?);
 958            }
 959        }
 960
 961        step_progress.set_substatus("done");
 962    }
 963
 964    Ok(all_examples)
 965}
 966
 967fn rated_examples_from_response<'a>(
 968    response: &'a SnowflakeStatementResponse,
 969    column_indices: &'a std::collections::HashMap<String, usize>,
 970) -> Result<impl Iterator<Item = Example> + 'a> {
 971    if let Some(code) = &response.code {
 972        if code != SNOWFLAKE_SUCCESS_CODE {
 973            anyhow::bail!(
 974                "snowflake sql api returned error code={code} message={}",
 975                response.message.as_deref().unwrap_or("<no message>")
 976            );
 977        }
 978    }
 979
 980    let iter = response
 981        .data
 982        .iter()
 983        .enumerate()
 984        .filter_map(move |(row_index, data_row)| {
 985            let get_string = |name: &str| -> Option<String> {
 986                let index = column_indices.get(name).copied()?;
 987                match data_row.get(index)? {
 988                    JsonValue::String(s) => Some(s.clone()),
 989                    JsonValue::Null => None,
 990                    other => Some(other.to_string()),
 991                }
 992            };
 993
 994            let get_json = |name: &str| -> Option<JsonValue> {
 995                let index = column_indices.get(name).copied()?;
 996                let value = data_row.get(index)?;
 997                if value.is_null() {
 998                    return None;
 999                }
1000                match value {
1001                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1002                    other => Some(other.clone()),
1003                }
1004            };
1005
1006            let request_id = get_string("request_id");
1007            let inputs_json = get_json("inputs");
1008            let inputs: Option<ZetaPromptInput> = match &inputs_json {
1009                Some(v) => match serde_json::from_value(v.clone()) {
1010                    Ok(parsed) => Some(parsed),
1011                    Err(e) => {
1012                        log::warn!(
1013                            "skipping row {row_index}: failed to parse inputs - {e}",
1014                        );
1015                        return None;
1016                    }
1017                },
1018                None => None,
1019            };
1020            let output = get_string("output");
1021            let rating = get_string("rating");
1022            let feedback = get_string("feedback").unwrap_or_default();
1023            let device_id = get_string("device_id");
1024            let time = get_string("time");
1025            let experiment_name = get_string("experiment_name");
1026            let environment = get_string("environment");
1027            let zed_version = get_string("zed_version");
1028
1029            match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
1030                (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
1031                    Some(build_rated_example(
1032                        request_id,
1033                        device_id,
1034                        time,
1035                        inputs,
1036                        output,
1037                        rating,
1038                        feedback,
1039                        experiment_name,
1040                        environment,
1041                        zed_version,
1042                    ))
1043                }
1044                _ => {
1045                    log::warn!(
1046                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
1047                        inputs_json.is_some(),
1048                        output.is_some(),
1049                        rating.is_some(),
1050                        device_id.is_some(),
1051                        time.is_some(),
1052                    );
1053                    None
1054                }
1055            }
1056        });
1057
1058    Ok(iter)
1059}
1060
1061fn build_rated_example(
1062    request_id: Option<String>,
1063    device_id: String,
1064    time: String,
1065    input: ZetaPromptInput,
1066    output: String,
1067    rating: String,
1068    feedback: String,
1069    experiment_name: Option<String>,
1070    environment: Option<String>,
1071    zed_version: Option<String>,
1072) -> Example {
1073    let parsed_rating = if rating == "Positive" {
1074        EditPredictionRating::Positive
1075    } else {
1076        EditPredictionRating::Negative
1077    };
1078    let is_positive = parsed_rating == EditPredictionRating::Positive;
1079    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1080
1081    let mut tags = Vec::with_capacity(3);
1082    tags.push(if is_positive {
1083        "rated:positive".to_string()
1084    } else {
1085        "rated:negative".to_string()
1086    });
1087    if let Some(experiment) = experiment_name {
1088        tags.push(format!("experiment:{experiment}"));
1089    }
1090    if let Some(env) = environment {
1091        tags.push(format!("environment:{env}"));
1092    }
1093
1094    let mut example =
1095        build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1096
1097    example.spec.rating = Some(parsed_rating);
1098
1099    if !feedback.is_empty() {
1100        example
1101            .spec
1102            .human_feedback
1103            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1104    }
1105
1106    if is_positive {
1107        example.spec.expected_patches = vec![output];
1108    } else {
1109        example.spec.rejected_patch = Some(output);
1110    }
1111
1112    example
1113}
1114
1115fn requested_examples_from_response<'a>(
1116    response: &'a SnowflakeStatementResponse,
1117    column_indices: &'a std::collections::HashMap<String, usize>,
1118) -> Result<impl Iterator<Item = Example> + 'a> {
1119    if let Some(code) = &response.code {
1120        if code != SNOWFLAKE_SUCCESS_CODE {
1121            anyhow::bail!(
1122                "snowflake sql api returned error code={code} message={}",
1123                response.message.as_deref().unwrap_or("<no message>")
1124            );
1125        }
1126    }
1127
1128    let iter = response
1129        .data
1130        .iter()
1131        .enumerate()
1132        .filter_map(move |(row_index, data_row)| {
1133            let get_string = |name: &str| -> Option<String> {
1134                let index = column_indices.get(name).copied()?;
1135                match data_row.get(index)? {
1136                    JsonValue::String(s) => Some(s.clone()),
1137                    JsonValue::Null => None,
1138                    other => Some(other.to_string()),
1139                }
1140            };
1141
1142            let get_json = |name: &str| -> Option<JsonValue> {
1143                let index = column_indices.get(name).copied()?;
1144                let value = data_row.get(index)?;
1145                if value.is_null() {
1146                    return None;
1147                }
1148                match value {
1149                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1150                    other => Some(other.clone()),
1151                }
1152            };
1153
1154            let request_id_str = get_string("request_id");
1155            let device_id = get_string("device_id");
1156            let time = get_string("time");
1157            let input_json = get_json("input");
1158            let input: Option<ZetaPromptInput> =
1159                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1160            let zed_version = get_string("zed_version");
1161
1162            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1163                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1164                    Some(build_example_from_snowflake(
1165                        request_id,
1166                        device_id,
1167                        time,
1168                        input,
1169                        vec!["requested".to_string()],
1170                        None,
1171                        zed_version,
1172                    ))
1173                }
1174                _ => {
1175                    log::warn!(
1176                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1177                        request_id_str.is_some(),
1178                        device_id.is_some(),
1179                        time.is_some(),
1180                        input_json.is_some(),
1181                    );
1182                    None
1183                }
1184            }
1185        });
1186
1187    Ok(iter)
1188}
1189
1190fn rejected_examples_from_response<'a>(
1191    response: &'a SnowflakeStatementResponse,
1192    column_indices: &'a std::collections::HashMap<String, usize>,
1193) -> Result<impl Iterator<Item = Example> + 'a> {
1194    if let Some(code) = &response.code {
1195        if code != SNOWFLAKE_SUCCESS_CODE {
1196            anyhow::bail!(
1197                "snowflake sql api returned error code={code} message={}",
1198                response.message.as_deref().unwrap_or("<no message>")
1199            );
1200        }
1201    }
1202
1203    let iter = response
1204        .data
1205        .iter()
1206        .enumerate()
1207        .filter_map(move |(row_index, data_row)| {
1208            let get_string = |name: &str| -> Option<String> {
1209                let index = column_indices.get(name).copied()?;
1210                match data_row.get(index)? {
1211                    JsonValue::String(s) => Some(s.clone()),
1212                    JsonValue::Null => None,
1213                    other => Some(other.to_string()),
1214                }
1215            };
1216
1217            let get_json = |name: &str| -> Option<JsonValue> {
1218                let index = column_indices.get(name).copied()?;
1219                let value = data_row.get(index)?;
1220                if value.is_null() {
1221                    return None;
1222                }
1223                match value {
1224                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1225                    other => Some(other.clone()),
1226                }
1227            };
1228
1229            let get_bool = |name: &str| -> Option<bool> {
1230                let index = column_indices.get(name).copied()?;
1231                match data_row.get(index)? {
1232                    JsonValue::Bool(b) => Some(*b),
1233                    JsonValue::String(s) => s.parse().ok(),
1234                    _ => None,
1235                }
1236            };
1237
1238            let request_id_str = get_string("request_id");
1239            let device_id = get_string("device_id");
1240            let time = get_string("time");
1241            let input_json = get_json("input");
1242            let input: Option<ZetaPromptInput> =
1243                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1244            let output = get_string("output");
1245            let was_shown = get_bool("was_shown");
1246            let reason = get_string("reason");
1247            let zed_version = get_string("zed_version");
1248
1249            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1250                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1251                    Some(build_rejected_example(
1252                        request_id,
1253                        device_id,
1254                        time,
1255                        input,
1256                        output,
1257                        was_shown,
1258                        reason,
1259                        zed_version,
1260                    ))
1261                }
1262                _ => {
1263                    log::warn!(
1264                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1265                        request_id_str.is_some(),
1266                        device_id.is_some(),
1267                        time.is_some(),
1268                        input_json.is_some(),
1269                        output.is_some(),
1270                        was_shown.is_some(),
1271                        reason.is_some()
1272                    );
1273                    None
1274                }
1275            }
1276        });
1277
1278    Ok(iter)
1279}
1280
1281fn build_rejected_example(
1282    request_id: String,
1283    device_id: String,
1284    time: String,
1285    input: ZetaPromptInput,
1286    output: String,
1287    was_shown: bool,
1288    reason: String,
1289    zed_version: Option<String>,
1290) -> Example {
1291    let rejected_patch = build_output_patch(
1292        &input.cursor_path,
1293        input.cursor_excerpt.as_ref(),
1294        &input.editable_range_in_excerpt,
1295        &output,
1296    );
1297    let mut example = build_example_from_snowflake(
1298        request_id,
1299        device_id,
1300        time,
1301        input,
1302        vec![format!("rejection:{}", reason.to_lowercase())],
1303        Some(RejectionInfo { reason, was_shown }),
1304        zed_version,
1305    );
1306    example.spec.rejected_patch = Some(rejected_patch);
1307    example
1308}
1309
/// Rejection metadata for a rejected prediction, copied into the example's telemetry source.
#[derive(Debug, Clone)]
struct RejectionInfo {
    /// Rejection reason string as reported in telemetry.
    reason: String,
    /// The telemetry `was_shown` flag for the rejected prediction.
    was_shown: bool,
}
1314
1315fn build_example_from_snowflake(
1316    request_id: String,
1317    device_id: String,
1318    time: String,
1319    input: ZetaPromptInput,
1320    tags: Vec<String>,
1321    rejection: Option<RejectionInfo>,
1322    zed_version: Option<String>,
1323) -> Example {
1324    let events: Vec<CapturedEvent> = input
1325        .events
1326        .iter()
1327        .map(|event| match event.as_ref() {
1328            zeta_prompt::Event::BufferChange {
1329                path,
1330                old_path,
1331                diff,
1332                predicted,
1333                in_open_source_repo,
1334            } => CapturedEvent {
1335                path: path.clone(),
1336                old_path: old_path.clone(),
1337                diff: diff.clone(),
1338                predicted: *predicted,
1339                in_open_source_repo: *in_open_source_repo,
1340            },
1341        })
1342        .collect();
1343
1344    let related_files: Vec<CapturedRelatedFile> = input
1345        .related_files
1346        .iter()
1347        .map(|rf| CapturedRelatedFile {
1348            path: rf.path.clone(),
1349            max_row: rf.max_row,
1350            excerpts: rf
1351                .excerpts
1352                .iter()
1353                .map(|e| CapturedRelatedExcerpt {
1354                    row_range: e.row_range.clone(),
1355                    text: e.text.to_string(),
1356                })
1357                .collect(),
1358        })
1359        .collect();
1360
1361    let cursor_excerpt = input.cursor_excerpt.as_ref();
1362    let cursor_offset = input.cursor_offset_in_excerpt;
1363
1364    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1365
1366    let mut edit_history = String::new();
1367    for event in &input.events {
1368        zeta_prompt::write_event(&mut edit_history, event);
1369        edit_history.push('\n');
1370    }
1371
1372    let (rejection_reason, was_shown) = match &rejection {
1373        Some(r) => (r.reason.clone(), r.was_shown),
1374        None => (String::new(), false),
1375    };
1376
1377    let spec = ExampleSpec {
1378        name: request_id.clone(),
1379        repository_url: String::new(),
1380        revision: String::new(),
1381        tags,
1382        reasoning: None,
1383        uncommitted_diff: String::new(),
1384        cursor_path: input.cursor_path.clone(),
1385        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1386        edit_history,
1387        expected_patches: Vec::new(),
1388        rejected_patch: None,
1389        captured_prompt_input: Some(CapturedPromptInput {
1390            cursor_file_content: cursor_excerpt.to_string(),
1391            cursor_offset,
1392            cursor_row,
1393            cursor_column,
1394            excerpt_start_row: None,
1395            events,
1396            related_files,
1397            in_open_source_repo: input.in_open_source_repo,
1398            zed_version,
1399        }),
1400        telemetry: Some(TelemetrySource {
1401            request_id,
1402            device_id,
1403            time,
1404            rejection_reason,
1405            was_shown,
1406        }),
1407        human_feedback: Vec::new(),
1408        rating: None,
1409    };
1410
1411    Example {
1412        spec,
1413        prompt_inputs: None,
1414        prompt: None,
1415        predictions: Vec::new(),
1416        score: Vec::new(),
1417        qa: Vec::new(),
1418        state: None,
1419    }
1420}
1421
/// Returns the zero-based `(row, column)` of byte `offset` within `text`.
///
/// The row is the number of `'\n'` characters that start strictly before `offset`. The
/// column is the byte distance from the start of that line to `offset`. An `offset` past the
/// end of `text` is not clamped, so the column can exceed the line's length.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // '\n' is a single ASCII byte that never occurs inside a multi-byte UTF-8 sequence, so
    // scanning raw bytes finds exactly the same newlines as scanning chars.
    let before = &text.as_bytes()[..offset.min(text.len())];
    let row = before.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = before
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline| newline + 1);
    (row, (offset - line_start) as u32)
}
1437
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte `cursor_offset`.
///
/// The offset comes from telemetry and is not trusted. It is clamped to the excerpt length.
/// If it falls inside a multi-byte character, it is moved back to the start of that
/// character instead of panicking on the slice.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    const MARKER: &str = "[CURSOR_POSITION]";

    let mut split = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split) {
        split -= 1;
    }
    let (before, after) = excerpt.split_at(split);

    let mut rendered = String::with_capacity(excerpt.len() + MARKER.len());
    rendered.push_str(before);
    rendered.push_str(MARKER);
    rendered.push_str(after);
    rendered
}
1443
1444fn build_output_patch(
1445    cursor_path: &std::path::Path,
1446    cursor_excerpt: &str,
1447    editable_range: &std::ops::Range<usize>,
1448    model_output: &str,
1449) -> String {
1450    let old_text = &cursor_excerpt[editable_range.clone()];
1451
1452    let editable_start_row = cursor_excerpt[..editable_range.start]
1453        .chars()
1454        .filter(|&c| c == '\n')
1455        .count() as u32;
1456
1457    let diff_body = language::unified_diff_with_offsets(
1458        old_text,
1459        model_output,
1460        editable_start_row,
1461        editable_start_row,
1462    );
1463
1464    let mut patch = String::new();
1465    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1466    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1467    patch.push_str(&diff_body);
1468    patch
1469}
1470
1471pub(crate) fn get_column_indices(
1472    meta: &Option<SnowflakeResultSetMetaData>,
1473    names: &[&str],
1474) -> std::collections::HashMap<String, usize> {
1475    let mut indices = std::collections::HashMap::new();
1476    if let Some(meta) = meta {
1477        for (index, col) in meta.row_type.iter().enumerate() {
1478            for &name in names {
1479                if col.name.eq_ignore_ascii_case(name) {
1480                    indices.insert(name.to_string(), index);
1481                }
1482            }
1483        }
1484    }
1485    indices
1486}