pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::io::Read;
   9use std::sync::Arc;
  10use std::time::Duration;
  11use telemetry_events::EditPredictionRating;
  12
  13use zeta_prompt::ZetaPromptInput;
  14
  15use crate::example::Example;
  16use crate::progress::{InfoStyle, Progress, Step};
/// Telemetry `event_type` of deployment events; the rated-examples query left-joins on it to
/// attach `experiment_name` and `environment` to each rated example.
const EDIT_PREDICTION_DEPLOYMENT_EVENT: &str = "Edit Prediction Deployment";
  18use edit_prediction::example_spec::{ExampleSpec, TelemetrySource};
  19use std::fmt::Write as _;
  20
  21pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
  22pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
  23const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
  24const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
  25const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";
  26
  27/// Minimum Zed version for filtering captured examples.
  28/// For example, `MinCaptureVersion { minor: 224, patch: 1 }` means only pull examples
  29/// where `zed_version >= 0.224.1`.
  30#[derive(Clone, Copy, Debug)]
  31pub struct MinCaptureVersion {
  32    pub minor: u32,
  33    pub patch: u32,
  34}
  35
  36const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
  37pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
  38pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
  39
  40/// Parse an input token of the form `captured-after:{timestamp}`.
  41pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  42    input.strip_prefix("captured-after:")
  43}
  44
  45/// Parse an input token of the form `rejected-after:{timestamp}`.
  46pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  47    input.strip_prefix("rejected-after:")
  48}
  49
  50/// Parse an input token of the form `requested-after:{timestamp}`.
  51pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  52    input.strip_prefix("requested-after:")
  53}
  54
  55/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  56/// or `rated-negative-after:{timestamp}`.
  57/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  58pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  59    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  60        Some((timestamp, Some(EditPredictionRating::Positive)))
  61    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  62        Some((timestamp, Some(EditPredictionRating::Negative)))
  63    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  64        Some((timestamp, None))
  65    } else {
  66        None
  67    }
  68}
  69
/// Subset of a Snowflake SQL API statement response that this module reads.
///
/// The same shape is parsed from the initial `POST /api/v2/statements` response and from
/// `GET /api/v2/statements/{handle}?partition=N` responses. Every field is
/// `#[serde(default)]`, so responses that omit any of them still deserialize.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON values looked up by column index.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Column metadata, row count and partition layout (JSON key `resultSetMetaData`).
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code, compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Message accompanying `code`; included in error reports when `code` is not the
    /// success code.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle used to poll an asynchronous statement and to fetch further result partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
  84
  85#[derive(Debug, Clone, Deserialize)]
  86#[serde(rename_all = "camelCase")]
  87pub(crate) struct SnowflakeResultSetMetaData {
  88    #[serde(default, rename = "rowType")]
  89    row_type: Vec<SnowflakeColumnMeta>,
  90    #[serde(default)]
  91    num_rows: Option<i64>,
  92    #[serde(default)]
  93    partition_info: Vec<SnowflakePartitionInfo>,
  94}
  95
  96#[derive(Debug, Clone, Deserialize)]
  97#[serde(rename_all = "camelCase")]
  98struct SnowflakePartitionInfo {}
  99
/// One column entry from `resultSetMetaData.rowType`; only the column name is read.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name as reported by Snowflake; empty if the key is missing.
    #[serde(default)]
    name: String,
}
 105
 106async fn run_sql_with_polling(
 107    http_client: Arc<dyn HttpClient>,
 108    base_url: &str,
 109    token: &str,
 110    request: &serde_json::Value,
 111    step_progress: &crate::progress::StepProgress,
 112    background_executor: BackgroundExecutor,
 113) -> Result<SnowflakeStatementResponse> {
 114    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 115
 116    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 117        let statement_handle = response
 118            .statement_handle
 119            .as_ref()
 120            .context("async query response missing statementHandle")?
 121            .clone();
 122
 123        for attempt in 1..=MAX_POLL_ATTEMPTS {
 124            step_progress.set_substatus(format!("polling ({attempt})"));
 125
 126            background_executor.timer(POLL_INTERVAL).await;
 127
 128            response =
 129                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 130
 131            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 132                break;
 133            }
 134        }
 135
 136        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 137            anyhow::bail!(
 138                "query still running after {} poll attempts ({} seconds)",
 139                MAX_POLL_ATTEMPTS,
 140                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 141            );
 142        }
 143    }
 144
 145    Ok(response)
 146}
 147
 148pub(crate) async fn fetch_partition(
 149    http_client: Arc<dyn HttpClient>,
 150    base_url: &str,
 151    token: &str,
 152    statement_handle: &str,
 153    partition: usize,
 154) -> Result<SnowflakeStatementResponse> {
 155    let url = format!(
 156        "{}/api/v2/statements/{}?partition={}",
 157        base_url.trim_end_matches('/'),
 158        statement_handle,
 159        partition
 160    );
 161
 162    let http_request = Request::builder()
 163        .method(Method::GET)
 164        .uri(url.as_str())
 165        .header("Authorization", format!("Bearer {token}"))
 166        .header(
 167            "X-Snowflake-Authorization-Token-Type",
 168            "PROGRAMMATIC_ACCESS_TOKEN",
 169        )
 170        .header("Accept", "application/json")
 171        .header("Accept-Encoding", "gzip")
 172        .header("User-Agent", "edit_prediction_cli")
 173        .body(AsyncBody::empty())?;
 174
 175    let response = http_client
 176        .send(http_request)
 177        .await
 178        .context("failed to send partition request to Snowflake SQL API")?;
 179
 180    let status = response.status();
 181    let content_encoding = response
 182        .headers()
 183        .get("content-encoding")
 184        .and_then(|v| v.to_str().ok())
 185        .map(|s| s.to_lowercase());
 186
 187    let body_bytes = {
 188        use futures::AsyncReadExt as _;
 189
 190        let mut body = response.into_body();
 191        let mut bytes = Vec::new();
 192        body.read_to_end(&mut bytes)
 193            .await
 194            .context("failed to read Snowflake SQL API partition response body")?;
 195        bytes
 196    };
 197
 198    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 199        let mut decoder = GzDecoder::new(&body_bytes[..]);
 200        let mut decompressed = Vec::new();
 201        decoder
 202            .read_to_end(&mut decompressed)
 203            .context("failed to decompress gzip response")?;
 204        decompressed
 205    } else {
 206        body_bytes
 207    };
 208
 209    if !status.is_success() && status.as_u16() != 202 {
 210        let body_text = String::from_utf8_lossy(&body_bytes);
 211        anyhow::bail!(
 212            "snowflake sql api partition request http {}: {}",
 213            status.as_u16(),
 214            body_text
 215        );
 216    }
 217
 218    if body_bytes.is_empty() {
 219        anyhow::bail!(
 220            "snowflake sql api partition {} returned empty response body (http {})",
 221            partition,
 222            status.as_u16()
 223        );
 224    }
 225
 226    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 227        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 228        format!(
 229            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 230            partition,
 231            status.as_u16(),
 232            body_preview
 233        )
 234    })
 235}
 236
 237pub(crate) async fn run_sql(
 238    http_client: Arc<dyn HttpClient>,
 239    base_url: &str,
 240    token: &str,
 241    request: &serde_json::Value,
 242) -> Result<SnowflakeStatementResponse> {
 243    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 244
 245    let request_body =
 246        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 247
 248    let http_request = Request::builder()
 249        .method(Method::POST)
 250        .uri(url.as_str())
 251        .header("Authorization", format!("Bearer {token}"))
 252        .header(
 253            "X-Snowflake-Authorization-Token-Type",
 254            "PROGRAMMATIC_ACCESS_TOKEN",
 255        )
 256        .header("Content-Type", "application/json")
 257        .header("Accept", "application/json")
 258        .header("User-Agent", "edit_prediction_cli")
 259        .body(AsyncBody::from(request_body.clone()))?;
 260
 261    let response = http_client
 262        .send(http_request)
 263        .await
 264        .context("failed to send request to Snowflake SQL API")?;
 265
 266    let status = response.status();
 267    let body_bytes = {
 268        use futures::AsyncReadExt as _;
 269
 270        let mut body = response.into_body();
 271        let mut bytes = Vec::new();
 272        body.read_to_end(&mut bytes)
 273            .await
 274            .context("failed to read Snowflake SQL API response body")?;
 275        bytes
 276    };
 277
 278    if !status.is_success() && status.as_u16() != 202 {
 279        let body_text = String::from_utf8_lossy(&body_bytes);
 280        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 281    }
 282
 283    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 284        .context("failed to parse Snowflake SQL API response JSON")
 285}
 286
/// Pulls examples of edit predictions that were shown to the user and then rejected.
///
/// For every timestamp in `after_timestamps`, one Snowflake query joins
/// "Predictive Edit Requested" events with "Predictive Edit Rejected" events on
/// `request_id`. It keeps only rows where:
/// - the request has `version = 'V3'` and `input:can_collect_data = true`;
/// - the rejection has `was_shown = true`;
/// - the request `time` is strictly after the timestamp (parsed with `TRY_TO_TIMESTAMP_NTZ`);
/// - when `min_capture_version` is `Some`, the request's `zed_version` minor/patch is at
///   least the given minimum. The major component is not compared, the patch is compared
///   after stripping any `+…` suffix, and versions that fail to parse are excluded.
///
/// Rows are ordered by request time. Each per-timestamp query skips `offset` rows and
/// returns at most `max_rows_per_timestamp`. All result partitions are fetched, and examples
/// from all timestamps are concatenated in input order.
///
/// Returns an empty list without reading the environment when `after_timestamps` is empty.
/// Otherwise reads `EP_SNOWFLAKE_API_KEY` and `EP_SNOWFLAKE_BASE_URL` (required) and
/// `EP_SNOWFLAKE_ROLE` (optional).
///
/// # Errors
///
/// Fails in any of these cases:
/// - a required environment variable is missing;
/// - a Snowflake request or poll fails;
/// - a multi-partition result has no statement handle;
/// - converting a response into examples fails.
pub async fn fetch_rejected_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Connection settings; an unset role is sent as JSON null in the request.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    // One query (plus any extra partition fetches) per timestamp, each with its own progress step.
    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("rejected>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // Join rejected events with their corresponding request events to get the full context.
        // We filter for V3 sampling data which contains the structured input we need.
        // We also filter for predictions that were actually shown to the user (was_shown = true)
        // to focus on explicit user rejections rather than implicit cancellations.
        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:prompt::string AS prompt,
                req.event_properties:output::string AS output,
                rej.event_properties:was_shown::boolean AS was_shown,
                rej.event_properties:reason::string AS reason,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            INNER JOIN events rej
                ON req.event_properties:request_id = rej.event_properties:request_id
            WHERE req.event_type = ?
                AND rej.event_type = ?
                AND req.event_properties:version = 'V3'
                AND rej.event_properties:was_shown = true
                AND req.event_properties:input:can_collect_data = true
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Version bounds are bound as strings. With no minimum they serialize to JSON null,
        // which the `? IS NULL` guard in the statement relies on to disable the version filter.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        // Binding keys are the 1-based positions of the `?` placeholders in `statement`.
        // The minor version fills three of them: the IS NULL guard, `>` and `=`.
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
                "3": { "type": "TEXT", "value": after_date },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_minor_str_ref },
                "7": { "type": "FIXED", "value": min_patch_str_ref },
                "8": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "9": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer Snowflake's reported row count; fall back to the rows in this response.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Always at least one partition: partition 0 is the data already in `response`.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Names correspond to the column aliases selected in `statement`.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "device_id",
                "time",
                "input",
                "prompt",
                "output",
                "was_shown",
                "reason",
                "zed_version",
            ],
        );

        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);

        // Fetch the remaining partitions (1..num_partitions) via the statement handle,
        // reusing the column mapping from the first response.
        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rejected_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
 452
/// Pulls examples from "Predictive Edit Requested" events, regardless of what the user did
/// with the prediction.
///
/// For every timestamp in `after_timestamps`, one Snowflake query selects request events
/// with `version = 'V3'`, `input:can_collect_data = true` and `time` strictly after the
/// timestamp (parsed with `TRY_TO_TIMESTAMP_NTZ`). When `min_capture_version` is `Some`, it
/// also requires `zed_version` minor/patch to be at least the given minimum. The major
/// component is not compared, the patch is compared after stripping any `+…` suffix, and
/// versions that fail to parse are excluded.
///
/// Rows are ordered by request time. Each per-timestamp query skips `offset` rows and
/// returns at most `max_rows_per_timestamp`. All result partitions are fetched, and examples
/// from all timestamps are concatenated in input order.
///
/// Returns an empty list without reading the environment when `after_timestamps` is empty.
/// Otherwise reads `EP_SNOWFLAKE_API_KEY` and `EP_SNOWFLAKE_BASE_URL` (required) and
/// `EP_SNOWFLAKE_ROLE` (optional).
///
/// # Errors
///
/// Fails in any of these cases:
/// - a required environment variable is missing;
/// - a Snowflake request or poll fails;
/// - a multi-partition result has no statement handle;
/// - converting a response into examples fails.
pub async fn fetch_requested_examples_after(
    http_client: Arc<dyn HttpClient>,
    after_timestamps: &[String],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if after_timestamps.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Connection settings; an unset role is sent as JSON null in the request.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    // One query (plus any extra partition fetches) per timestamp, each with its own progress step.
    for after_date in after_timestamps.iter() {
        let step_progress_name = format!("requested>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        let statement = indoc! {r#"
            SELECT
                req.event_properties:request_id::string AS request_id,
                req.device_id::string AS device_id,
                req.time::string AS time,
                req.event_properties:input AS input,
                req.event_properties:zed_version::string AS zed_version
            FROM events req
            WHERE req.event_type = ?
                AND req.event_properties:version = 'V3'
                AND req.event_properties:input:can_collect_data = true
                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND (? IS NULL OR (
                    TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) > ?
                    OR (
                        TRY_CAST(SPLIT_PART(req.event_properties:zed_version::string, '.', 2) AS INTEGER) = ?
                        AND TRY_CAST(SPLIT_PART(SPLIT_PART(req.event_properties:zed_version::string, '.', 3), '+', 1) AS INTEGER) >= ?
                    )
                ))
            ORDER BY req.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Version bounds are bound as strings. With no minimum they serialize to JSON null,
        // which the `? IS NULL` guard in the statement relies on to disable the version filter.
        let min_minor_str = min_capture_version.map(|v| v.minor.to_string());
        let min_patch_str = min_capture_version.map(|v| v.patch.to_string());
        let min_minor_str_ref = min_minor_str.as_deref();
        let min_patch_str_ref = min_patch_str.as_deref();
        // Binding keys are the 1-based positions of the `?` placeholders in `statement`.
        // The minor version fills three of them: the IS NULL guard, `>` and `=`.
        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": {
                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
                "2": { "type": "TEXT", "value": after_date },
                "3": { "type": "FIXED", "value": min_minor_str_ref },
                "4": { "type": "FIXED", "value": min_minor_str_ref },
                "5": { "type": "FIXED", "value": min_minor_str_ref },
                "6": { "type": "FIXED", "value": min_patch_str_ref },
                "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
                "8": { "type": "FIXED", "value": offset.to_string() }
            }
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer Snowflake's reported row count; fall back to the rows in this response.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Always at least one partition: partition 0 is the data already in `response`.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Names correspond to the column aliases selected in `statement`.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &["request_id", "device_id", "time", "input", "zed_version"],
        );

        all_examples.extend(requested_examples_from_response(
            &response,
            &column_indices,
        )?);

        // Fetch the remaining partitions (1..num_partitions) via the statement handle,
        // reusing the column mapping from the first response.
        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(requested_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
 598
/// Pulls examples from "Edit Prediction Rated" telemetry events.
///
/// Each `(timestamp, rating_filter)` in `inputs` runs one Snowflake query over rated events
/// that meet all of these conditions:
/// - `time` is strictly after the timestamp (parsed with `TRY_TO_TIMESTAMP_NTZ`);
/// - `can_collect_data = true`;
/// - `inputs`, `inputs:cursor_excerpt` and `output` are non-null.
///
/// A `Some` `rating_filter` restricts `rating` to `"Positive"`/`"Negative"`; `None` keeps
/// every rating.
///
/// Each rated event is left-joined to its request event (by `request_id`). Through the
/// request's Baseten model id/version headers it is then joined to a deployment event,
/// which supplies `experiment_name` and `environment` (null when nothing matches).
///
/// Rows are ordered by rated time. Each query skips `offset` rows and returns at most
/// `max_rows_per_timestamp`. All result partitions are fetched, and examples from all
/// inputs are concatenated in input order.
///
/// NOTE(review): `_min_capture_version` is not used, so rated examples are not filtered by
/// `zed_version` (presumably kept for signature parity with the other fetchers — confirm).
///
/// Returns an empty list without reading the environment when `inputs` is empty.
/// Otherwise reads `EP_SNOWFLAKE_API_KEY` and `EP_SNOWFLAKE_BASE_URL` (required) and
/// `EP_SNOWFLAKE_ROLE` (optional).
///
/// # Errors
///
/// Fails in any of these cases:
/// - a required environment variable is missing;
/// - a Snowflake request or poll fails;
/// - a multi-partition result has no statement handle;
/// - converting a response into examples fails.
pub async fn fetch_rated_examples_after(
    http_client: Arc<dyn HttpClient>,
    inputs: &[(String, Option<EditPredictionRating>)],
    max_rows_per_timestamp: usize,
    offset: usize,
    background_executor: BackgroundExecutor,
    _min_capture_version: Option<MinCaptureVersion>,
) -> Result<Vec<Example>> {
    if inputs.is_empty() {
        return Ok(Vec::new());
    }

    let progress = Progress::global();

    // Connection settings; an unset role is sent as JSON null in the request.
    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
    )?;
    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();

    let mut all_examples = Vec::new();

    for (after_date, rating_filter) in inputs.iter() {
        // Progress step name encodes the rating filter, e.g. `rated:positive>{timestamp}`.
        let filter_label = match rating_filter {
            None => "",
            Some(EditPredictionRating::Positive) => ":positive",
            Some(EditPredictionRating::Negative) => ":negative",
        };
        let step_progress_name = format!("rated{filter_label}>{after_date}");
        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
        step_progress.set_substatus("querying");

        // `None` serializes to JSON null, which the `? IS NULL` guard relies on to keep
        // every rating.
        let rating_value = rating_filter.as_ref().map(|r| match r {
            EditPredictionRating::Positive => "Positive",
            EditPredictionRating::Negative => "Negative",
        });

        let statement = indoc! {r#"
            SELECT
                rated.event_properties:request_id::string AS request_id,
                rated.event_properties:inputs AS inputs,
                rated.event_properties:output::string AS output,
                rated.event_properties:rating::string AS rating,
                rated.event_properties:feedback::string AS feedback,
                rated.device_id::string AS device_id,
                rated.time::string AS time,
                deploy.event_properties:experiment_name::string AS experiment_name,
                deploy.event_properties:environment::string AS environment,
                rated.event_properties:zed_version::string AS zed_version
            FROM events rated
            LEFT JOIN events req
                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
                AND req.event_type = ?
            LEFT JOIN events deploy
                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
                AND deploy.event_type = ?
            WHERE rated.event_type = ?
                AND (? IS NULL OR rated.event_properties:rating::string = ?)
                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
                AND rated.event_properties:inputs IS NOT NULL
                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
                AND rated.event_properties:output IS NOT NULL
                AND rated.event_properties:can_collect_data = true
            ORDER BY rated.time ASC
            LIMIT ?
            OFFSET ?
        "#};

        // Binding keys are the 1-based positions of the `?` placeholders in `statement`:
        // the two join event types, the rated event type, the rating (guard and comparison),
        // the timestamp, then LIMIT and OFFSET.
        let bindings = json!({
            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
            "4": { "type": "TEXT", "value": rating_value },
            "5": { "type": "TEXT", "value": rating_value },
            "6": { "type": "TEXT", "value": after_date },
            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() },
            "8": { "type": "FIXED", "value": offset.to_string() }
        });

        let request = json!({
            "statement": statement,
            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
            "database": "EVENTS",
            "schema": "PUBLIC",
            "warehouse": "DBT",
            "role": role,
            "bindings": bindings
        });

        let response = run_sql_with_polling(
            http_client.clone(),
            &base_url,
            &token,
            &request,
            &step_progress,
            background_executor.clone(),
        )
        .await?;

        // Prefer Snowflake's reported row count; fall back to the rows in this response.
        let total_rows = response
            .result_set_meta_data
            .as_ref()
            .and_then(|m| m.num_rows)
            .unwrap_or(response.data.len() as i64);

        // Always at least one partition: partition 0 is the data already in `response`.
        let num_partitions = response
            .result_set_meta_data
            .as_ref()
            .map(|m| m.partition_info.len())
            .unwrap_or(1)
            .max(1);

        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
        step_progress.set_substatus("parsing");

        // Names correspond to the column aliases selected in `statement`.
        let column_indices = get_column_indices(
            &response.result_set_meta_data,
            &[
                "request_id",
                "inputs",
                "output",
                "rating",
                "feedback",
                "device_id",
                "time",
                "experiment_name",
                "environment",
                "zed_version",
            ],
        );

        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);

        // Fetch the remaining partitions (1..num_partitions) via the statement handle,
        // reusing the column mapping from the first response.
        if num_partitions > 1 {
            let statement_handle = response
                .statement_handle
                .as_ref()
                .context("response has multiple partitions but no statementHandle")?;

            for partition in 1..num_partitions {
                step_progress.set_substatus(format!(
                    "fetching partition {}/{}",
                    partition + 1,
                    num_partitions
                ));

                let partition_response = fetch_partition(
                    http_client.clone(),
                    &base_url,
                    &token,
                    statement_handle,
                    partition,
                )
                .await?;

                all_examples.extend(rated_examples_from_response(
                    &partition_response,
                    &column_indices,
                )?);
            }
        }

        step_progress.set_substatus("done");
    }

    Ok(all_examples)
}
 768
 769fn rated_examples_from_response<'a>(
 770    response: &'a SnowflakeStatementResponse,
 771    column_indices: &'a std::collections::HashMap<String, usize>,
 772) -> Result<impl Iterator<Item = Example> + 'a> {
 773    if let Some(code) = &response.code {
 774        if code != SNOWFLAKE_SUCCESS_CODE {
 775            anyhow::bail!(
 776                "snowflake sql api returned error code={code} message={}",
 777                response.message.as_deref().unwrap_or("<no message>")
 778            );
 779        }
 780    }
 781
 782    let iter = response
 783        .data
 784        .iter()
 785        .enumerate()
 786        .filter_map(move |(row_index, data_row)| {
 787            let get_string = |name: &str| -> Option<String> {
 788                let index = column_indices.get(name).copied()?;
 789                match data_row.get(index)? {
 790                    JsonValue::String(s) => Some(s.clone()),
 791                    JsonValue::Null => None,
 792                    other => Some(other.to_string()),
 793                }
 794            };
 795
 796            let get_json = |name: &str| -> Option<JsonValue> {
 797                let index = column_indices.get(name).copied()?;
 798                let value = data_row.get(index)?;
 799                if value.is_null() {
 800                    return None;
 801                }
 802                match value {
 803                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 804                    other => Some(other.clone()),
 805                }
 806            };
 807
 808            let request_id = get_string("request_id");
 809            let inputs_json = get_json("inputs");
 810            let inputs: Option<ZetaPromptInput> = match &inputs_json {
 811                Some(v) => match serde_json::from_value(v.clone()) {
 812                    Ok(parsed) => Some(parsed),
 813                    Err(e) => {
 814                        log::warn!(
 815                            "skipping row {row_index}: failed to parse inputs - {e}",
 816                        );
 817                        return None;
 818                    }
 819                },
 820                None => None,
 821            };
 822            let output = get_string("output");
 823            let rating = get_string("rating");
 824            let feedback = get_string("feedback").unwrap_or_default();
 825            let device_id = get_string("device_id");
 826            let time = get_string("time");
 827            let experiment_name = get_string("experiment_name");
 828            let environment = get_string("environment");
 829            let zed_version = get_string("zed_version");
 830
 831            match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
 832                (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
 833                    Some(build_rated_example(
 834                        request_id,
 835                        device_id,
 836                        time,
 837                        inputs,
 838                        output,
 839                        rating,
 840                        feedback,
 841                        experiment_name,
 842                        environment,
 843                        zed_version,
 844                    ))
 845                }
 846                _ => {
 847                    log::warn!(
 848                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
 849                        inputs_json.is_some(),
 850                        output.is_some(),
 851                        rating.is_some(),
 852                        device_id.is_some(),
 853                        time.is_some(),
 854                    );
 855                    None
 856                }
 857            }
 858        });
 859
 860    Ok(iter)
 861}
 862
 863fn build_rated_example(
 864    request_id: Option<String>,
 865    device_id: String,
 866    time: String,
 867    input: ZetaPromptInput,
 868    output: String,
 869    rating: String,
 870    feedback: String,
 871    experiment_name: Option<String>,
 872    environment: Option<String>,
 873    zed_version: Option<String>,
 874) -> Example {
 875    let parsed_rating = if rating == "Positive" {
 876        EditPredictionRating::Positive
 877    } else {
 878        EditPredictionRating::Negative
 879    };
 880    let is_positive = parsed_rating == EditPredictionRating::Positive;
 881    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
 882
 883    let mut tags = Vec::with_capacity(3);
 884    tags.push(if is_positive {
 885        "rated:positive".to_string()
 886    } else {
 887        "rated:negative".to_string()
 888    });
 889    if let Some(experiment) = experiment_name {
 890        tags.push(format!("experiment:{experiment}"));
 891    }
 892    if let Some(env) = environment {
 893        tags.push(format!("environment:{env}"));
 894    }
 895
 896    let mut example =
 897        build_example_from_snowflake(request_id, device_id, time, input, tags, None, zed_version);
 898
 899    example.spec.rating = Some(parsed_rating);
 900
 901    if !feedback.is_empty() {
 902        example
 903            .spec
 904            .human_feedback
 905            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
 906    }
 907
 908    if is_positive {
 909        example.spec.expected_patches = vec![output];
 910    } else {
 911        example.spec.rejected_patch = Some(output);
 912    }
 913
 914    example
 915}
 916
 917fn requested_examples_from_response<'a>(
 918    response: &'a SnowflakeStatementResponse,
 919    column_indices: &'a std::collections::HashMap<String, usize>,
 920) -> Result<impl Iterator<Item = Example> + 'a> {
 921    if let Some(code) = &response.code {
 922        if code != SNOWFLAKE_SUCCESS_CODE {
 923            anyhow::bail!(
 924                "snowflake sql api returned error code={code} message={}",
 925                response.message.as_deref().unwrap_or("<no message>")
 926            );
 927        }
 928    }
 929
 930    let iter = response
 931        .data
 932        .iter()
 933        .enumerate()
 934        .filter_map(move |(row_index, data_row)| {
 935            let get_string = |name: &str| -> Option<String> {
 936                let index = column_indices.get(name).copied()?;
 937                match data_row.get(index)? {
 938                    JsonValue::String(s) => Some(s.clone()),
 939                    JsonValue::Null => None,
 940                    other => Some(other.to_string()),
 941                }
 942            };
 943
 944            let get_json = |name: &str| -> Option<JsonValue> {
 945                let index = column_indices.get(name).copied()?;
 946                let value = data_row.get(index)?;
 947                if value.is_null() {
 948                    return None;
 949                }
 950                match value {
 951                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 952                    other => Some(other.clone()),
 953                }
 954            };
 955
 956            let request_id_str = get_string("request_id");
 957            let device_id = get_string("device_id");
 958            let time = get_string("time");
 959            let input_json = get_json("input");
 960            let input: Option<ZetaPromptInput> =
 961                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
 962            let zed_version = get_string("zed_version");
 963
 964            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
 965                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
 966                    Some(build_example_from_snowflake(
 967                        request_id,
 968                        device_id,
 969                        time,
 970                        input,
 971                        vec!["requested".to_string()],
 972                        None,
 973                        zed_version,
 974                    ))
 975                }
 976                _ => {
 977                    log::warn!(
 978                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
 979                        request_id_str.is_some(),
 980                        device_id.is_some(),
 981                        time.is_some(),
 982                        input_json.is_some(),
 983                    );
 984                    None
 985                }
 986            }
 987        });
 988
 989    Ok(iter)
 990}
 991
 992fn rejected_examples_from_response<'a>(
 993    response: &'a SnowflakeStatementResponse,
 994    column_indices: &'a std::collections::HashMap<String, usize>,
 995) -> Result<impl Iterator<Item = Example> + 'a> {
 996    if let Some(code) = &response.code {
 997        if code != SNOWFLAKE_SUCCESS_CODE {
 998            anyhow::bail!(
 999                "snowflake sql api returned error code={code} message={}",
1000                response.message.as_deref().unwrap_or("<no message>")
1001            );
1002        }
1003    }
1004
1005    let iter = response
1006        .data
1007        .iter()
1008        .enumerate()
1009        .filter_map(move |(row_index, data_row)| {
1010            let get_string = |name: &str| -> Option<String> {
1011                let index = column_indices.get(name).copied()?;
1012                match data_row.get(index)? {
1013                    JsonValue::String(s) => Some(s.clone()),
1014                    JsonValue::Null => None,
1015                    other => Some(other.to_string()),
1016                }
1017            };
1018
1019            let get_json = |name: &str| -> Option<JsonValue> {
1020                let index = column_indices.get(name).copied()?;
1021                let value = data_row.get(index)?;
1022                if value.is_null() {
1023                    return None;
1024                }
1025                match value {
1026                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1027                    other => Some(other.clone()),
1028                }
1029            };
1030
1031            let get_bool = |name: &str| -> Option<bool> {
1032                let index = column_indices.get(name).copied()?;
1033                match data_row.get(index)? {
1034                    JsonValue::Bool(b) => Some(*b),
1035                    JsonValue::String(s) => s.parse().ok(),
1036                    _ => None,
1037                }
1038            };
1039
1040            let request_id_str = get_string("request_id");
1041            let device_id = get_string("device_id");
1042            let time = get_string("time");
1043            let input_json = get_json("input");
1044            let input: Option<ZetaPromptInput> =
1045                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1046            let output = get_string("output");
1047            let was_shown = get_bool("was_shown");
1048            let reason = get_string("reason");
1049            let zed_version = get_string("zed_version");
1050
1051            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1052                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1053                    Some(build_rejected_example(
1054                        request_id,
1055                        device_id,
1056                        time,
1057                        input,
1058                        output,
1059                        was_shown,
1060                        reason,
1061                        zed_version,
1062                    ))
1063                }
1064                _ => {
1065                    log::warn!(
1066                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1067                        request_id_str.is_some(),
1068                        device_id.is_some(),
1069                        time.is_some(),
1070                        input_json.is_some(),
1071                        output.is_some(),
1072                        was_shown.is_some(),
1073                        reason.is_some()
1074                    );
1075                    None
1076                }
1077            }
1078        });
1079
1080    Ok(iter)
1081}
1082
1083fn build_rejected_example(
1084    request_id: String,
1085    device_id: String,
1086    time: String,
1087    input: ZetaPromptInput,
1088    output: String,
1089    was_shown: bool,
1090    reason: String,
1091    zed_version: Option<String>,
1092) -> Example {
1093    let rejected_patch = build_output_patch(
1094        &input.cursor_path,
1095        input.cursor_excerpt.as_ref(),
1096        &input.editable_range_in_excerpt,
1097        &output,
1098    );
1099    let mut example = build_example_from_snowflake(
1100        request_id,
1101        device_id,
1102        time,
1103        input,
1104        vec![format!("rejection:{}", reason.to_lowercase())],
1105        Some(RejectionInfo { reason, was_shown }),
1106        zed_version,
1107    );
1108    example.spec.rejected_patch = Some(rejected_patch);
1109    example
1110}
1111
/// Rejection details copied from a telemetry row into the example's
/// telemetry metadata.
struct RejectionInfo {
    /// The row's `was_shown` flag (presumably whether the prediction was
    /// displayed to the user — confirm against the telemetry schema).
    was_shown: bool,
    /// The rejection reason reported by telemetry.
    reason: String,
}
1116
1117fn build_example_from_snowflake(
1118    request_id: String,
1119    device_id: String,
1120    time: String,
1121    input: ZetaPromptInput,
1122    tags: Vec<String>,
1123    rejection: Option<RejectionInfo>,
1124    zed_version: Option<String>,
1125) -> Example {
1126    let cursor_excerpt = input.cursor_excerpt.as_ref();
1127    let cursor_offset = input.cursor_offset_in_excerpt;
1128
1129    let mut edit_history = String::new();
1130    for event in &input.events {
1131        zeta_prompt::write_event(&mut edit_history, event);
1132        edit_history.push('\n');
1133    }
1134
1135    let (rejection_reason, was_shown) = match &rejection {
1136        Some(r) => (r.reason.clone(), r.was_shown),
1137        None => (String::new(), false),
1138    };
1139
1140    let spec = ExampleSpec {
1141        name: request_id.clone(),
1142        repository_url: String::new(),
1143        revision: String::new(),
1144        tags,
1145        reasoning: None,
1146        uncommitted_diff: String::new(),
1147        cursor_path: input.cursor_path.clone(),
1148        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1149        edit_history,
1150        expected_patches: Vec::new(),
1151        rejected_patch: None,
1152        telemetry: Some(TelemetrySource {
1153            request_id,
1154            device_id,
1155            time,
1156            rejection_reason,
1157            was_shown,
1158        }),
1159        human_feedback: Vec::new(),
1160        rating: None,
1161    };
1162
1163    Example {
1164        spec,
1165        zed_version,
1166        prompt_inputs: Some(input),
1167        prompt: None,
1168        predictions: Vec::new(),
1169        score: Vec::new(),
1170        qa: Vec::new(),
1171        state: None,
1172    }
1173}
1174
/// Returns `excerpt` with a `[CURSOR_POSITION]` marker inserted at byte offset
/// `cursor_offset`.
///
/// Offsets past the end place the marker at the end. Because the offset comes
/// from telemetry data, an offset that lands inside a multi-byte UTF-8
/// character is moved back to the start of that character instead of
/// panicking on the slice.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Offset 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }
    let (before, after) = excerpt.split_at(split_at);
    format!("{before}[CURSOR_POSITION]{after}")
}
1180
1181fn build_output_patch(
1182    cursor_path: &std::path::Path,
1183    cursor_excerpt: &str,
1184    editable_range: &std::ops::Range<usize>,
1185    model_output: &str,
1186) -> String {
1187    let old_text = &cursor_excerpt[editable_range.clone()];
1188
1189    let editable_start_row = cursor_excerpt[..editable_range.start]
1190        .chars()
1191        .filter(|&c| c == '\n')
1192        .count() as u32;
1193
1194    let diff_body = language::unified_diff_with_offsets(
1195        old_text,
1196        model_output,
1197        editable_start_row,
1198        editable_start_row,
1199    );
1200
1201    let mut patch = String::new();
1202    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1203    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1204    patch.push_str(&diff_body);
1205    patch
1206}
1207
1208pub(crate) fn get_column_indices(
1209    meta: &Option<SnowflakeResultSetMetaData>,
1210    names: &[&str],
1211) -> std::collections::HashMap<String, usize> {
1212    let mut indices = std::collections::HashMap::new();
1213    if let Some(meta) = meta {
1214        for (index, col) in meta.row_type.iter().enumerate() {
1215            for &name in names {
1216                if col.name.eq_ignore_ascii_case(name) {
1217                    indices.insert(name.to_string(), index);
1218                }
1219            }
1220        }
1221    }
1222    indices
1223}