pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::io::Read;
   9use std::sync::Arc;
  10use std::time::Duration;
  11use telemetry_events::EditPredictionRating;
  12
  13use zeta_prompt::ZetaPromptInput;
  14
  15use crate::example::Example;
  16use crate::progress::{InfoStyle, Progress, Step};
// Telemetry event name for edit prediction deployments.
// NOTE(review): presumably bound to `deploy.event_type` in the rated-examples query — confirm.
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
  18use edit_prediction::example_spec::{
  19    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
  20    TelemetrySource,
  21};
  22use std::fmt::Write as _;
  23
/// Response `code` value treated as success; `examples_from_response` rejects any other code
/// that is present in a response.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Response `code` value that makes `run_sql_with_polling` keep polling for the statement result.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
/// `event_type` bound by the captured-examples query.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
/// `event_type` bound for the request (`req`) side of the rejected/requested queries.
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
/// `event_type` bound for the rejection (`rej`) side of the rejected-examples query.
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
/// Telemetry event name for rated predictions.
/// NOTE(review): presumably bound to `rated.event_type` in the rated-examples query — confirm.
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
  30
  31/// Minimum Zed version for filtering captured examples.
  32/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
  33/// where `zed_version >= 0.224.1`.
  34#[derive(Clone, Copy, Debug)]
  35pub struct MinCaptureVersion {
  36    pub minor: u32,
  37    pub patch: u32,
  38}
  39
/// Value sent as the `timeout` field of every statement request built in this file.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each poll in `run_sql_with_polling`.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls `run_sql_with_polling` performs before giving up on a statement.
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
  43
  44/// Parse an input token of the form `captured-after:{timestamp}`.
  45pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  46    input.strip_prefix("captured-after:")
  47}
  48
  49/// Parse an input token of the form `rejected-after:{timestamp}`.
  50pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  51    input.strip_prefix("rejected-after:")
  52}
  53
  54/// Parse an input token of the form `requested-after:{timestamp}`.
  55pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  56    input.strip_prefix("requested-after:")
  57}
  58
  59/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  60/// or `rated-negative-after:{timestamp}`.
  61/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  62pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  63    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  64        Some((timestamp, Some(EditPredictionRating::Positive)))
  65    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  66        Some((timestamp, Some(EditPredictionRating::Negative)))
  67    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  68        Some((timestamp, None))
  69    } else {
  70        None
  71    }
  72}
  73
  74pub async fn fetch_captured_examples_after(
  75    http_client: Arc<dyn HttpClient>,
  76    after_timestamps: &[String],
  77    max_rows_per_timestamp: usize,
  78    offset: usize,
  79    background_executor: BackgroundExecutor,
  80    _min_capture_version: Option<MinCaptureVersion>,
  81) -> Result<Vec<Example>> {
  82    if after_timestamps.is_empty() {
  83        return Ok(Vec::new());
  84    }
  85
  86    let progress = Progress::global();
  87
  88    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
  89        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
  90    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
  91        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
  92    )?;
  93    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
  94
  95    let mut all_examples = Vec::new();
  96
  97    for after_date in after_timestamps.iter() {
  98        let step_progress_name = format!(">{after_date}");
  99        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 100        step_progress.set_substatus("querying");
 101
 102        let statement = indoc! {r#"
 103            SELECT
 104                event_properties:example AS example
 105            FROM events
 106            WHERE event_type = ?
 107                AND time > TRY_TO_TIMESTAMP_NTZ(?)
 108                AND event_properties:can_collect_data = true
 109            ORDER BY time ASC
 110            LIMIT ?
 111            OFFSET ?
 112        "#};
 113
 114        let request = json!({
 115            "statement": statement,
 116            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 117            "database": "EVENTS",
 118            "schema": "PUBLIC",
 119            "warehouse": "DBT",
 120            "role": role,
 121            "bindings": {
 122                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 123                "2": { "type": "TEXT", "value": after_date },
 124                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 125                "4": { "type": "FIXED", "value": offset.to_string() }
 126            }
 127        });
 128
 129        let response = run_sql_with_polling(
 130            http_client.clone(),
 131            &base_url,
 132            &token,
 133            &request,
 134            &step_progress,
 135            background_executor.clone(),
 136        )
 137        .await?;
 138
 139        let total_rows = response
 140            .result_set_meta_data
 141            .as_ref()
 142            .and_then(|m| m.num_rows)
 143            .unwrap_or(response.data.len() as i64);
 144
 145        let num_partitions = response
 146            .result_set_meta_data
 147            .as_ref()
 148            .map(|m| m.partition_info.len())
 149            .unwrap_or(1)
 150            .max(1);
 151
 152        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 153        step_progress.set_substatus("parsing");
 154
 155        let example_index = response
 156            .result_set_meta_data
 157            .as_ref()
 158            .and_then(|m| {
 159                m.row_type.iter().enumerate().find_map(|(index, col)| {
 160                    if col.name.eq_ignore_ascii_case("example") {
 161                        Some(index)
 162                    } else {
 163                        None
 164                    }
 165                })
 166            })
 167            .unwrap_or(0);
 168
 169        all_examples.extend(examples_from_response(&response, example_index)?);
 170
 171        if num_partitions > 1 {
 172            let statement_handle = response
 173                .statement_handle
 174                .as_ref()
 175                .context("response has multiple partitions but no statementHandle")?;
 176
 177            for partition in 1..num_partitions {
 178                step_progress.set_substatus(format!(
 179                    "fetching partition {}/{}",
 180                    partition + 1,
 181                    num_partitions
 182                ));
 183
 184                let partition_response = fetch_partition(
 185                    http_client.clone(),
 186                    &base_url,
 187                    &token,
 188                    statement_handle,
 189                    partition,
 190                )
 191                .await?;
 192
 193                all_examples.extend(examples_from_response(&partition_response, example_index)?);
 194            }
 195        }
 196
 197        step_progress.set_substatus("done");
 198    }
 199
 200    Ok(all_examples)
 201}
 202
 203#[derive(Debug, Clone, Deserialize)]
 204#[serde(rename_all = "camelCase")]
 205pub(crate) struct SnowflakeStatementResponse {
 206    #[serde(default)]
 207    pub(crate) data: Vec<Vec<JsonValue>>,
 208    #[serde(default)]
 209    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
 210    #[serde(default)]
 211    pub(crate) code: Option<String>,
 212    #[serde(default)]
 213    pub(crate) message: Option<String>,
 214    #[serde(default)]
 215    pub(crate) statement_handle: Option<String>,
 216}
 217
 218#[derive(Debug, Clone, Deserialize)]
 219#[serde(rename_all = "camelCase")]
 220pub(crate) struct SnowflakeResultSetMetaData {
 221    #[serde(default, rename = "rowType")]
 222    row_type: Vec<SnowflakeColumnMeta>,
 223    #[serde(default)]
 224    num_rows: Option<i64>,
 225    #[serde(default)]
 226    partition_info: Vec<SnowflakePartitionInfo>,
 227}
 228
 229#[derive(Debug, Clone, Deserialize)]
 230#[serde(rename_all = "camelCase")]
 231struct SnowflakePartitionInfo {}
 232
 233#[derive(Debug, Clone, Deserialize)]
 234struct SnowflakeColumnMeta {
 235    #[serde(default)]
 236    name: String,
 237}
 238
 239fn examples_from_response(
 240    response: &SnowflakeStatementResponse,
 241    example_index: usize,
 242) -> Result<impl Iterator<Item = Example> + '_> {
 243    if let Some(code) = &response.code {
 244        if code != SNOWFLAKE_SUCCESS_CODE {
 245            anyhow::bail!(
 246                "snowflake sql api returned error code={code} message={}",
 247                response.message.as_deref().unwrap_or("<no message>")
 248            );
 249        }
 250    }
 251
 252    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
 253        let Some(example_value) = data_row.get(example_index) else {
 254            return None;
 255        };
 256        if example_value.is_null() {
 257            return None;
 258        }
 259
 260        let parse_result = match example_value {
 261            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
 262            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
 263        };
 264
 265        match parse_result {
 266            Ok(spec) => Some(Example {
 267                spec,
 268                prompt_inputs: None,
 269                prompt: None,
 270                predictions: Vec::new(),
 271                score: Vec::new(),
 272                qa: Vec::new(),
 273                state: None,
 274            }),
 275            Err(error) => {
 276                let raw_json = serde_json::to_string_pretty(example_value)
 277                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
 278                log::error!(
 279                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
 280                );
 281                None
 282            }
 283        }
 284    });
 285
 286    Ok(iter)
 287}
 288
 289async fn run_sql_with_polling(
 290    http_client: Arc<dyn HttpClient>,
 291    base_url: &str,
 292    token: &str,
 293    request: &serde_json::Value,
 294    step_progress: &crate::progress::StepProgress,
 295    background_executor: BackgroundExecutor,
 296) -> Result<SnowflakeStatementResponse> {
 297    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 298
 299    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 300        let statement_handle = response
 301            .statement_handle
 302            .as_ref()
 303            .context("async query response missing statementHandle")?
 304            .clone();
 305
 306        for attempt in 1..=MAX_POLL_ATTEMPTS {
 307            step_progress.set_substatus(format!("polling ({attempt})"));
 308
 309            background_executor.timer(POLL_INTERVAL).await;
 310
 311            response =
 312                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 313
 314            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 315                break;
 316            }
 317        }
 318
 319        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 320            anyhow::bail!(
 321                "query still running after {} poll attempts ({} seconds)",
 322                MAX_POLL_ATTEMPTS,
 323                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 324            );
 325        }
 326    }
 327
 328    Ok(response)
 329}
 330
 331pub(crate) async fn fetch_partition(
 332    http_client: Arc<dyn HttpClient>,
 333    base_url: &str,
 334    token: &str,
 335    statement_handle: &str,
 336    partition: usize,
 337) -> Result<SnowflakeStatementResponse> {
 338    let url = format!(
 339        "{}/api/v2/statements/{}?partition={}",
 340        base_url.trim_end_matches('/'),
 341        statement_handle,
 342        partition
 343    );
 344
 345    let http_request = Request::builder()
 346        .method(Method::GET)
 347        .uri(url.as_str())
 348        .header("Authorization", format!("Bearer {token}"))
 349        .header(
 350            "X-Snowflake-Authorization-Token-Type",
 351            "PROGRAMMATIC_ACCESS_TOKEN",
 352        )
 353        .header("Accept", "application/json")
 354        .header("Accept-Encoding", "gzip")
 355        .header("User-Agent", "edit_prediction_cli")
 356        .body(AsyncBody::empty())?;
 357
 358    let response = http_client
 359        .send(http_request)
 360        .await
 361        .context("failed to send partition request to Snowflake SQL API")?;
 362
 363    let status = response.status();
 364    let content_encoding = response
 365        .headers()
 366        .get("content-encoding")
 367        .and_then(|v| v.to_str().ok())
 368        .map(|s| s.to_lowercase());
 369
 370    let body_bytes = {
 371        use futures::AsyncReadExt as _;
 372
 373        let mut body = response.into_body();
 374        let mut bytes = Vec::new();
 375        body.read_to_end(&mut bytes)
 376            .await
 377            .context("failed to read Snowflake SQL API partition response body")?;
 378        bytes
 379    };
 380
 381    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 382        let mut decoder = GzDecoder::new(&body_bytes[..]);
 383        let mut decompressed = Vec::new();
 384        decoder
 385            .read_to_end(&mut decompressed)
 386            .context("failed to decompress gzip response")?;
 387        decompressed
 388    } else {
 389        body_bytes
 390    };
 391
 392    if !status.is_success() && status.as_u16() != 202 {
 393        let body_text = String::from_utf8_lossy(&body_bytes);
 394        anyhow::bail!(
 395            "snowflake sql api partition request http {}: {}",
 396            status.as_u16(),
 397            body_text
 398        );
 399    }
 400
 401    if body_bytes.is_empty() {
 402        anyhow::bail!(
 403            "snowflake sql api partition {} returned empty response body (http {})",
 404            partition,
 405            status.as_u16()
 406        );
 407    }
 408
 409    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 410        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 411        format!(
 412            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 413            partition,
 414            status.as_u16(),
 415            body_preview
 416        )
 417    })
 418}
 419
 420pub(crate) async fn run_sql(
 421    http_client: Arc<dyn HttpClient>,
 422    base_url: &str,
 423    token: &str,
 424    request: &serde_json::Value,
 425) -> Result<SnowflakeStatementResponse> {
 426    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 427
 428    let request_body =
 429        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 430
 431    let http_request = Request::builder()
 432        .method(Method::POST)
 433        .uri(url.as_str())
 434        .header("Authorization", format!("Bearer {token}"))
 435        .header(
 436            "X-Snowflake-Authorization-Token-Type",
 437            "PROGRAMMATIC_ACCESS_TOKEN",
 438        )
 439        .header("Content-Type", "application/json")
 440        .header("Accept", "application/json")
 441        .header("User-Agent", "edit_prediction_cli")
 442        .body(AsyncBody::from(request_body.clone()))?;
 443
 444    let response = http_client
 445        .send(http_request)
 446        .await
 447        .context("failed to send request to Snowflake SQL API")?;
 448
 449    let status = response.status();
 450    let body_bytes = {
 451        use futures::AsyncReadExt as _;
 452
 453        let mut body = response.into_body();
 454        let mut bytes = Vec::new();
 455        body.read_to_end(&mut bytes)
 456            .await
 457            .context("failed to read Snowflake SQL API response body")?;
 458        bytes
 459    };
 460
 461    if !status.is_success() && status.as_u16() != 202 {
 462        let body_text = String::from_utf8_lossy(&body_bytes);
 463        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 464    }
 465
 466    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 467        .context("failed to parse Snowflake SQL API response JSON")
 468}
 469
/// Fetches examples for predictions that were shown and then rejected, for requests made
/// after each of the given timestamps.
///
/// For every entry in `after_timestamps` one statement is run that joins request events with
/// rejection events on `request_id`, optionally restricted to requests whose `zed_version` is
/// at least `min_capture_version`. Results are limited to `max_rows_per_timestamp` rows
/// starting at `offset`, ordered by request time ascending. Every result partition is
/// downloaded and converted with `rejected_examples_from_response`.
///
/// Returns an empty vector immediately when `after_timestamps` is empty.
///
/// # Errors
///
/// Fails when `EP_SNOWFLAKE_API_KEY` or `EP_SNOWFLAKE_BASE_URL` is unset, when a request to
/// the SQL API fails, when a multi-partition response carries no statement handle, or when
/// `rejected_examples_from_response` returns an error.
pub async fn fetch_rejected_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    // The role is optional; when unset it is serialized as JSON null in the request.
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("rejected>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // Join rejected events with their corresponding request events to get the full context.
        // We filter for V3 sampling data which contains the structured input we need.
        // We also filter for predictions that were actually shown to the user (was_shown = true)
        // to focus on explicit user rejections rather than implicit cancellations.
        //
        // The version filter compares the second and third dot-separated components of
        // `zed_version` (minor and patch; anything after a `+` in the patch is dropped).
        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:prompt::string AS prompt,
                req.event_properties:output::string AS output,
                rej.event_properties:was_shown::boolean AS was_shown,
                rej.event_properties:reason::string AS reason,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            INNER JOIN events rej
                ON req.event_properties:request_id = rej.event_properties:request_id
            WHERE req.event_type = ?
                AND rej.event_type = ?
                AND req.event_properties:version = 'V3'
                AND rej.event_properties:was_shown = true
                AND req.event_properties:input:can_collect_data = true
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Bindings are positional and must stay aligned with the `?` placeholders above:
        // 1-2 event types, 3 timestamp, 4-6 minimum minor version (once for the `IS NULL`
        // test, twice for the comparisons), 7 minimum patch version, 8 LIMIT, 9 OFFSET.
        // Without a minimum version the values serialize as JSON null.
        // NOTE(review): presumably bound as SQL NULL, which disables the version filter — confirm.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
                "3": { "type": "TEXT", "value": after_date },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_minor_str_ref },
                "7": { "type": "FIXED", "value": min_patch_str_ref },
                "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "9": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Fall back to the number of rows in the first partition when no count is reported.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // A response without partition info is treated as a single partition.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Column positions are resolved once from the first response's metadata and reused
        // for every partition of the same statement.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "device_id",
                "time",
                "input",
                "prompt",
                "output",
                "was_shown",
                "reason",
                "zed_version",
            ],
        );

        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            // Partition 0 arrived with the initial response; fetch the remaining ones.
            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rejected_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
 635
/// Fetches examples built from prediction request events made after each of the given
/// timestamps.
///
/// For every entry in `after_timestamps` one statement is run that selects V3 request events
/// with `can_collect_data = true`, optionally restricted to requests whose `zed_version` is at
/// least `min_capture_version`. Results are limited to `max_rows_per_timestamp` rows starting
/// at `offset`, ordered by request time ascending. Every result partition is downloaded and
/// converted with `requested_examples_from_response`.
///
/// Returns an empty vector immediately when `after_timestamps` is empty.
///
/// # Errors
///
/// Fails when `EP_SNOWFLAKE_API_KEY` or `EP_SNOWFLAKE_BASE_URL` is unset, when a request to
/// the SQL API fails, when a multi-partition response carries no statement handle, or when
/// `requested_examples_from_response` returns an error.
pub async fn fetch_requested_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    // The role is optional; when unset it is serialized as JSON null in the request.
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("requested>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // The version filter compares the second and third dot-separated components of
        // `zed_version` (minor and patch; anything after a `+` in the patch is dropped).
        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            WHERE req.event_type = ?
                AND req.event_properties:version = 'V3'
                AND req.event_properties:input:can_collect_data = true
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Bindings are positional and must stay aligned with the `?` placeholders above:
        // 1 event type, 2 timestamp, 3-5 minimum minor version (once for the `IS NULL` test,
        // twice for the comparisons), 6 minimum patch version, 7 LIMIT, 8 OFFSET.
        // Without a minimum version the values serialize as JSON null.
        // NOTE(review): presumably bound as SQL NULL, which disables the version filter — confirm.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": after_date },
                "3": { "type": "FIXED", "value": min_minor_str_ref },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_patch_str_ref },
                "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "8": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Fall back to the number of rows in the first partition when no count is reported.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // A response without partition info is treated as a single partition.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Column positions are resolved once from the first response's metadata and reused
        // for every partition of the same statement.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &["request_id", "device_id", "time", "input", "zed_version"],
        );

        all_examples.extend(requested_examples_from_response(
            &response,
            &column_indices,
        )?);

        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            // Partition 0 arrived with the initial response; fetch the remaining ones.
            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(requested_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
 781
 782pub async fn fetch_rated_examples_after(
 783    http_client: Arc<dyn HttpClient>,
 784    inputs: &[(String, Option<EditPredictionRating>)],
 785    max_rows_per_timestamp: usize,
 786    offset: usize,
 787    background_executor: BackgroundExecutor,
 788    _min_capture_version: Option<MinCaptureVersion>,
 789) -> Result<Vec<Example>> {
 790    if inputs.is_empty() {
 791        return Ok(Vec::new());
 792    }
 793
 794    let progress = Progress::global();
 795
 796    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 797        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 798    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 799        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 800    )?;
 801    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 802
 803    let mut all_examples = Vec::new();
 804
 805    for (after_date, rating_filter) in inputs.iter() {
 806        let filter_label = match rating_filter {
 807            None => "",
 808            Some(EditPredictionRating::Positive) => ":positive",
 809            Some(EditPredictionRating::Negative) => ":negative",
 810        };
 811        let step_progress_name = format!("rated{filter_label}>{after_date}");
 812        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 813        step_progress.set_substatus("querying");
 814
 815        let rating_value = rating_filter.as_ref().map(|r| match r {
 816            EditPredictionRating::Positive => "Positive",
 817            EditPredictionRating::Negative => "Negative",
 818        });
 819
 820        let statement = indoc! {r#"
 821            SELECT
 822                rated.event_properties:request_id::string AS request_id,
 823                rated.event_properties:inputs AS inputs,
 824                rated.event_properties:output::string AS output,
 825                rated.event_properties:rating::string AS rating,
 826                rated.event_properties:feedback::string AS feedback,
 827                rated.device_id::string AS device_id,
 828                rated.time::string AS time,
 829                deploy.event_properties:experiment_name::string AS experiment_name,
 830                deploy.event_properties:environment::string AS environment,
 831                rated.event_properties:zed_version::string AS zed_version
 832            FROM events rated
 833            LEFT JOIN events req
 834                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
 835                AND req.event_type = ?
 836            LEFT JOIN events deploy
 837                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
 838                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
 839                AND deploy.event_type = ?
 840            WHERE rated.event_type = ?
 841                AND (? IS NULL OR rated.event_properties:rating::string = ?)
 842                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
 843                AND rated.event_properties:inputs IS NOT NULL
 844                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
 845                AND rated.event_properties:output IS NOT NULL
 846                AND rated.event_properties:can_collect_data = true
 847            ORDER BY rated.time ASC
 848            LIMIT ?
 849            OFFSET ?
 850        "#};
 851
 852        let bindings = json!({
 853            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 854            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
 855            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
 856            "4": { "type": "TEXT", "value": rating_value },
 857            "5": { "type": "TEXT", "value": rating_value },
 858            "6": { "type": "TEXT", "value": after_date },
 859            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
 860            "8": { "type": "FIXED", "value": offset.to_string() }
 861        });
 862
 863        let request = json!({
 864            "statement": statement,
 865            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 866            "database": "EVENTS",
 867            "schema": "PUBLIC",
 868            "warehouse": "DBT",
 869            "role": role,
 870            "bindings": bindings
 871        });
 872
 873        let response = run_sql_with_polling(
 874            http_client.clone(),
 875            &base_url,
 876            &token,
 877            &request,
 878            &step_progress,
 879            background_executor.clone(),
 880        )
 881        .await?;
 882
 883        let total_rows = response
 884            .result_set_meta_data
 885            .as_ref()
 886            .and_then(|m| m.num_rows)
 887            .unwrap_or(response.data.len() as i64);
 888
 889        let num_partitions = response
 890            .result_set_meta_data
 891            .as_ref()
 892            .map(|m| m.partition_info.len())
 893            .unwrap_or(1)
 894            .max(1);
 895
 896        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 897        step_progress.set_substatus("parsing");
 898
 899        let column_indices = get_column_indices(
 900            &response.result_set_meta_data,
 901            &[
 902                "request_id",
 903                "inputs",
 904                "output",
 905                "rating",
 906                "feedback",
 907                "device_id",
 908                "time",
 909                "experiment_name",
 910                "environment",
 911                "zed_version",
 912            ],
 913        );
 914
 915        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
 916
 917        if num_partitions > 1 {
 918            let statement_handle = response
 919                .statement_handle
 920                .as_ref()
 921                .context("response has multiple partitions but no statementHandle")?;
 922
 923            for partition in 1..num_partitions {
 924                step_progress.set_substatus(format!(
 925                    "fetching partition {}/{}",
 926                    partition + 1,
 927                    num_partitions
 928                ));
 929
 930                let partition_response = fetch_partition(
 931                    http_client.clone(),
 932                    &base_url,
 933                    &token,
 934                    statement_handle,
 935                    partition,
 936                )
 937                .await?;
 938
 939                all_examples.extend(rated_examples_from_response(
 940                    &partition_response,
 941                    &column_indices,
 942                )?);
 943            }
 944        }
 945
 946        step_progress.set_substatus("done");
 947    }
 948
 949    Ok(all_examples)
 950}
 951
 952fn rated_examples_from_response<'a>(
 953    response: &'a SnowflakeStatementResponse,
 954    column_indices: &'a std::collections::HashMap<String, usize>,
 955) -> Result<impl Iterator<Item = Example> + 'a> {
 956    if let Some(code) = &response.code {
 957        if code != SNOWFLAKE_SUCCESS_CODE {
 958            anyhow::bail!(
 959                "snowflake sql api returned error code={code} message={}",
 960                response.message.as_deref().unwrap_or("<no message>")
 961            );
 962        }
 963    }
 964
 965    let iter = response
 966        .data
 967        .iter()
 968        .enumerate()
 969        .filter_map(move |(row_index, data_row)| {
 970            let get_string = |name: &str| -> Option<String> {
 971                let index = column_indices.get(name).copied()?;
 972                match data_row.get(index)? {
 973                    JsonValue::String(s) => Some(s.clone()),
 974                    JsonValue::Null => None,
 975                    other => Some(other.to_string()),
 976                }
 977            };
 978
 979            let get_json = |name: &str| -> Option<JsonValue> {
 980                let index = column_indices.get(name).copied()?;
 981                let value = data_row.get(index)?;
 982                if value.is_null() {
 983                    return None;
 984                }
 985                match value {
 986                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 987                    other => Some(other.clone()),
 988                }
 989            };
 990
 991            let request_id = get_string("request_id");
 992            let inputs_json = get_json("inputs");
 993            let inputs: Option<ZetaPromptInput> = match &inputs_json {
 994                Some(v) => match serde_json::from_value(v.clone()) {
 995                    Ok(parsed) => Some(parsed),
 996                    Err(e) => {
 997                        log::warn!(
 998                            "skipping row {row_index}: failed to parse inputs - {e}",
 999                        );
1000                        return None;
1001                    }
1002                },
1003                None => None,
1004            };
1005            let output = get_string("output");
1006            let rating = get_string("rating");
1007            let feedback = get_string("feedback").unwrap_or_default();
1008            let device_id = get_string("device_id");
1009            let time = get_string("time");
1010            let experiment_name = get_string("experiment_name");
1011            let environment = get_string("environment");
1012            let zed_version = get_string("zed_version");
1013
1014            match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
1015                (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
1016                    Some(build_rated_example(
1017                        request_id,
1018                        device_id,
1019                        time,
1020                        inputs,
1021                        output,
1022                        rating,
1023                        feedback,
1024                        experiment_name,
1025                        environment,
1026                        zed_version,
1027                    ))
1028                }
1029                _ => {
1030                    log::warn!(
1031                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
1032                        inputs_json.is_some(),
1033                        output.is_some(),
1034                        rating.is_some(),
1035                        device_id.is_some(),
1036                        time.is_some(),
1037                    );
1038                    None
1039                }
1040            }
1041        });
1042
1043    Ok(iter)
1044}
1045
1046fn build_rated_example(
1047    request_id: Option<String>,
1048    device_id: String,
1049    time: String,
1050    input: ZetaPromptInput,
1051    output: String,
1052    rating: String,
1053    feedback: String,
1054    experiment_name: Option<String>,
1055    environment: Option<String>,
1056    zed_version: Option<String>,
1057) -> Example {
1058    let parsed_rating = if rating == "Positive" {
1059        EditPredictionRating::Positive
1060    } else {
1061        EditPredictionRating::Negative
1062    };
1063    let is_positive = parsed_rating == EditPredictionRating::Positive;
1064    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
1065
1066    let mut tags = Vec::with_capacity(3);
1067    tags.push(if is_positive {
1068        "rated:positive".to_string()
1069    } else {
1070        "rated:negative".to_string()
1071    });
1072    if let Some(experiment) = experiment_name {
1073        tags.push(format!("experiment:{experiment}"));
1074    }
1075    if let Some(env) = environment {
1076        tags.push(format!("environment:{env}"));
1077    }
1078
1079    let mut example =
1080        build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
1081
1082    example.spec.rating = Some(parsed_rating);
1083
1084    if !feedback.is_empty() {
1085        example
1086            .spec
1087            .human_feedback
1088            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1089    }
1090
1091    if is_positive {
1092        example.spec.expected_patches = vec![output];
1093    } else {
1094        example.spec.rejected_patch = Some(output);
1095    }
1096
1097    example
1098}
1099
1100fn requested_examples_from_response<'a>(
1101    response: &'a SnowflakeStatementResponse,
1102    column_indices: &'a std::collections::HashMap<String, usize>,
1103) -> Result<impl Iterator<Item = Example> + 'a> {
1104    if let Some(code) = &response.code {
1105        if code != SNOWFLAKE_SUCCESS_CODE {
1106            anyhow::bail!(
1107                "snowflake sql api returned error code={code} message={}",
1108                response.message.as_deref().unwrap_or("<no message>")
1109            );
1110        }
1111    }
1112
1113    let iter = response
1114        .data
1115        .iter()
1116        .enumerate()
1117        .filter_map(move |(row_index, data_row)| {
1118            let get_string = |name: &str| -> Option<String> {
1119                let index = column_indices.get(name).copied()?;
1120                match data_row.get(index)? {
1121                    JsonValue::String(s) => Some(s.clone()),
1122                    JsonValue::Null => None,
1123                    other => Some(other.to_string()),
1124                }
1125            };
1126
1127            let get_json = |name: &str| -> Option<JsonValue> {
1128                let index = column_indices.get(name).copied()?;
1129                let value = data_row.get(index)?;
1130                if value.is_null() {
1131                    return None;
1132                }
1133                match value {
1134                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1135                    other => Some(other.clone()),
1136                }
1137            };
1138
1139            let request_id_str = get_string("request_id");
1140            let device_id = get_string("device_id");
1141            let time = get_string("time");
1142            let input_json = get_json("input");
1143            let input: Option<ZetaPromptInput> =
1144                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1145            let zed_version = get_string("zed_version");
1146
1147            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1148                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1149                    Some(build_example_from_snowflake(
1150                        request_id,
1151                        device_id,
1152                        time,
1153                        input,
1154                        vec!["requested".to_string()],
1155                        None,
1156                        zed_version,
1157                    ))
1158                }
1159                _ => {
1160                    log::warn!(
1161                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1162                        request_id_str.is_some(),
1163                        device_id.is_some(),
1164                        time.is_some(),
1165                        input_json.is_some(),
1166                    );
1167                    None
1168                }
1169            }
1170        });
1171
1172    Ok(iter)
1173}
1174
1175fn rejected_examples_from_response<'a>(
1176    response: &'a SnowflakeStatementResponse,
1177    column_indices: &'a std::collections::HashMap<String, usize>,
1178) -> Result<impl Iterator<Item = Example> + 'a> {
1179    if let Some(code) = &response.code {
1180        if code != SNOWFLAKE_SUCCESS_CODE {
1181            anyhow::bail!(
1182                "snowflake sql api returned error code={code} message={}",
1183                response.message.as_deref().unwrap_or("<no message>")
1184            );
1185        }
1186    }
1187
1188    let iter = response
1189        .data
1190        .iter()
1191        .enumerate()
1192        .filter_map(move |(row_index, data_row)| {
1193            let get_string = |name: &str| -> Option<String> {
1194                let index = column_indices.get(name).copied()?;
1195                match data_row.get(index)? {
1196                    JsonValue::String(s) => Some(s.clone()),
1197                    JsonValue::Null => None,
1198                    other => Some(other.to_string()),
1199                }
1200            };
1201
1202            let get_json = |name: &str| -> Option<JsonValue> {
1203                let index = column_indices.get(name).copied()?;
1204                let value = data_row.get(index)?;
1205                if value.is_null() {
1206                    return None;
1207                }
1208                match value {
1209                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1210                    other => Some(other.clone()),
1211                }
1212            };
1213
1214            let get_bool = |name: &str| -> Option<bool> {
1215                let index = column_indices.get(name).copied()?;
1216                match data_row.get(index)? {
1217                    JsonValue::Bool(b) => Some(*b),
1218                    JsonValue::String(s) => s.parse().ok(),
1219                    _ => None,
1220                }
1221            };
1222
1223            let request_id_str = get_string("request_id");
1224            let device_id = get_string("device_id");
1225            let time = get_string("time");
1226            let input_json = get_json("input");
1227            let input: Option<ZetaPromptInput> =
1228                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1229            let output = get_string("output");
1230            let was_shown = get_bool("was_shown");
1231            let reason = get_string("reason");
1232            let zed_version = get_string("zed_version");
1233
1234            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1235                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1236                    Some(build_rejected_example(
1237                        request_id,
1238                        device_id,
1239                        time,
1240                        input,
1241                        output,
1242                        was_shown,
1243                        reason,
1244                        zed_version,
1245                    ))
1246                }
1247                _ => {
1248                    log::warn!(
1249                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1250                        request_id_str.is_some(),
1251                        device_id.is_some(),
1252                        time.is_some(),
1253                        input_json.is_some(),
1254                        output.is_some(),
1255                        was_shown.is_some(),
1256                        reason.is_some()
1257                    );
1258                    None
1259                }
1260            }
1261        });
1262
1263    Ok(iter)
1264}
1265
1266fn build_rejected_example(
1267    request_id: String,
1268    device_id: String,
1269    time: String,
1270    input: ZetaPromptInput,
1271    output: String,
1272    was_shown: bool,
1273    reason: String,
1274    zed_version: Option<String>,
1275) -> Example {
1276    let rejected_patch = build_output_patch(
1277        &input.cursor_path,
1278        input.cursor_excerpt.as_ref(),
1279        &input.editable_range_in_excerpt,
1280        &output,
1281    );
1282    let mut example = build_example_from_snowflake(
1283        request_id,
1284        device_id,
1285        time,
1286        input,
1287        vec![format!("rejection:{}", reason.to_lowercase())],
1288        Some(RejectionInfo { reason, was_shown }),
1289        zed_version,
1290    );
1291    example.spec.rejected_patch = Some(rejected_patch);
1292    example
1293}
1294
/// Rejection details attached to an example built from a rejected prediction.
struct RejectionInfo {
    /// Rejection reason as read from the `reason` column.
    reason: String,
    /// Value of the `was_shown` column for the rejected prediction.
    was_shown: bool,
}
1299
1300fn build_example_from_snowflake(
1301    request_id: String,
1302    device_id: String,
1303    time: String,
1304    input: ZetaPromptInput,
1305    tags: Vec<String>,
1306    rejection: Option<RejectionInfo>,
1307    zed_version: Option<String>,
1308) -> Example {
1309    let events: Vec<CapturedEvent> = input
1310        .events
1311        .iter()
1312        .map(|event| match event.as_ref() {
1313            zeta_prompt::Event::BufferChange {
1314                path,
1315                old_path,
1316                diff,
1317                predicted,
1318                in_open_source_repo,
1319            } => CapturedEvent {
1320                path: path.clone(),
1321                old_path: old_path.clone(),
1322                diff: diff.clone(),
1323                predicted: *predicted,
1324                in_open_source_repo: *in_open_source_repo,
1325            },
1326        })
1327        .collect();
1328
1329    let related_files: Vec<CapturedRelatedFile> = input
1330        .related_files
1331        .iter()
1332        .map(|rf| CapturedRelatedFile {
1333            path: rf.path.clone(),
1334            max_row: rf.max_row,
1335            excerpts: rf
1336                .excerpts
1337                .iter()
1338                .map(|e| CapturedRelatedExcerpt {
1339                    row_range: e.row_range.clone(),
1340                    text: e.text.to_string(),
1341                })
1342                .collect(),
1343        })
1344        .collect();
1345
1346    let cursor_excerpt = input.cursor_excerpt.as_ref();
1347    let cursor_offset = input.cursor_offset_in_excerpt;
1348
1349    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1350
1351    let mut edit_history = String::new();
1352    for event in &input.events {
1353        zeta_prompt::write_event(&mut edit_history, event);
1354        edit_history.push('\n');
1355    }
1356
1357    let (rejection_reason, was_shown) = match &rejection {
1358        Some(r) => (r.reason.clone(), r.was_shown),
1359        None => (String::new(), false),
1360    };
1361
1362    let spec = ExampleSpec {
1363        name: request_id.clone(),
1364        repository_url: String::new(),
1365        revision: String::new(),
1366        tags,
1367        reasoning: None,
1368        uncommitted_diff: String::new(),
1369        cursor_path: input.cursor_path.clone(),
1370        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1371        edit_history,
1372        expected_patches: Vec::new(),
1373        rejected_patch: None,
1374        captured_prompt_input: Some(CapturedPromptInput {
1375            cursor_file_content: cursor_excerpt.to_string(),
1376            cursor_offset,
1377            cursor_row,
1378            cursor_column,
1379            excerpt_start_row: None,
1380            events,
1381            related_files,
1382            in_open_source_repo: input.in_open_source_repo,
1383            zed_version,
1384        }),
1385        telemetry: Some(TelemetrySource {
1386            request_id,
1387            device_id,
1388            time,
1389            rejection_reason,
1390            was_shown,
1391        }),
1392        human_feedback: Vec::new(),
1393        rating: None,
1394    };
1395
1396    Example {
1397        spec,
1398        prompt_inputs: None,
1399        prompt: None,
1400        predictions: Vec::new(),
1401        score: Vec::new(),
1402        qa: Vec::new(),
1403        state: None,
1404    }
1405}
1406
/// Converts a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// `row` is the number of `\n` characters that start before `offset`; `column` is the
/// number of bytes between the start of that line and `offset`. An `offset` past the end
/// of `text` keeps counting columns beyond the last line's content.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // `\n` is a single byte that never occurs inside a multi-byte UTF-8 sequence, so the
    // scan can work on raw bytes and tolerate offsets that are not char boundaries.
    let scanned = &text.as_bytes()[..offset.min(text.len())];

    let row = scanned.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = scanned
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline| newline + 1);

    (row, (offset - line_start) as u32)
}
1422
/// Returns `excerpt` with the literal marker `[CURSOR_POSITION]` inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset into `excerpt`. It is clamped to `excerpt.len()`, and
/// when it falls inside a multi-byte UTF-8 character it is moved back to the start of that
/// character, so an inconsistent offset can never cause a slicing panic.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1428
1429fn build_output_patch(
1430    cursor_path: &std::path::Path,
1431    cursor_excerpt: &str,
1432    editable_range: &std::ops::Range<usize>,
1433    model_output: &str,
1434) -> String {
1435    let old_text = &cursor_excerpt[editable_range.clone()];
1436
1437    let editable_start_row = cursor_excerpt[..editable_range.start]
1438        .chars()
1439        .filter(|&c| c == '\n')
1440        .count() as u32;
1441
1442    let diff_body = language::unified_diff_with_offsets(
1443        old_text,
1444        model_output,
1445        editable_start_row,
1446        editable_start_row,
1447    );
1448
1449    let mut patch = String::new();
1450    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1451    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1452    patch.push_str(&diff_body);
1453    patch
1454}
1455
1456pub(crate) fn get_column_indices(
1457    meta: &Option<SnowflakeResultSetMetaData>,
1458    names: &[&str],
1459) -> std::collections::HashMap<String, usize> {
1460    let mut indices = std::collections::HashMap::new();
1461    if let Some(meta) = meta {
1462        for (index, col) in meta.row_type.iter().enumerate() {
1463            for &name in names {
1464                if col.name.eq_ignore_ascii_case(name) {
1465                    indices.insert(name.to_string(), index);
1466                }
1467            }
1468        }
1469    }
1470    indices
1471}