pull_examples.rs

   1use anyhow::{Context as _, Result};
   2use flate2::read::GzDecoder;
   3use gpui::BackgroundExecutor;
   4use http_client::{AsyncBody, HttpClient, Method, Request};
   5use indoc::indoc;
   6use serde::Deserialize;
   7use serde_json::{Value as JsonValue, json};
   8use std::io::Read;
   9use std::sync::Arc;
  10use std::time::Duration;
  11use telemetry_events::EditPredictionRating;
  12
  13use zeta_prompt::ZetaPromptInput;
  14
  15use crate::example::Example;
  16use crate::progress::{InfoStyle, Progress, Step};
  17use crate::sync_deployments::EDIT_PREDICTION_DEPLOYMENT_EVENT;
  18use edit_prediction::example_spec::{
  19    CapturedEvent, CapturedPromptInput, CapturedRelatedExcerpt, CapturedRelatedFile, ExampleSpec,
  20    TelemetrySource,
  21};
  22use std::fmt::Write as _;
  23
/// Value of a response's `code` field that this module accepts as success; any other code
/// present on a response makes `examples_from_response` return an error.
pub(crate) const SNOWFLAKE_SUCCESS_CODE: &str = "090001";
/// Value of a response's `code` field that `run_sql_with_polling` treats as "statement still
/// executing" and answers by polling.
pub(crate) const SNOWFLAKE_ASYNC_IN_PROGRESS_CODE: &str = "333334";
// Event type names bound into the `event_type = ?` filters of the queries in this file.
const EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT: &str = "Edit Prediction Example Captured";
const PREDICTIVE_EDIT_REQUESTED_EVENT: &str = "Predictive Edit Requested";
const PREDICTIVE_EDIT_REJECTED_EVENT: &str = "Predictive Edit Rejected";
const EDIT_PREDICTION_RATED_EVENT: &str = "Edit Prediction Rated";

/// Sent as the `timeout` field of every statement request.
// NOTE(review): presumably a server-side execution limit in seconds — confirm against the API.
const DEFAULT_STATEMENT_TIMEOUT_SECONDS: u64 = 120;
/// Delay awaited before each poll of an in-progress statement.
pub(crate) const POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum number of polls before `run_sql_with_polling` gives up
/// (120 polls at a 2 second interval is 240 seconds).
pub(crate) const MAX_POLL_ATTEMPTS: usize = 120;
  34
  35/// Parse an input token of the form `captured-after:{timestamp}`.
  36pub fn parse_captured_after_input(input: &str) -> Option<&str> {
  37    input.strip_prefix("captured-after:")
  38}
  39
  40/// Parse an input token of the form `rejected-after:{timestamp}`.
  41pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
  42    input.strip_prefix("rejected-after:")
  43}
  44
  45/// Parse an input token of the form `requested-after:{timestamp}`.
  46pub fn parse_requested_after_input(input: &str) -> Option<&str> {
  47    input.strip_prefix("requested-after:")
  48}
  49
  50/// Parse an input token of the form `rated-after:{timestamp}`, `rated-positive-after:{timestamp}`,
  51/// or `rated-negative-after:{timestamp}`.
  52/// Returns `(timestamp, Option<EditPredictionRating>)` where `None` means all ratings.
  53pub fn parse_rated_after_input(input: &str) -> Option<(&str, Option<EditPredictionRating>)> {
  54    if let Some(timestamp) = input.strip_prefix("rated-positive-after:") {
  55        Some((timestamp, Some(EditPredictionRating::Positive)))
  56    } else if let Some(timestamp) = input.strip_prefix("rated-negative-after:") {
  57        Some((timestamp, Some(EditPredictionRating::Negative)))
  58    } else if let Some(timestamp) = input.strip_prefix("rated-after:") {
  59        Some((timestamp, None))
  60    } else {
  61        None
  62    }
  63}
  64
  65pub async fn fetch_captured_examples_after(
  66    http_client: Arc<dyn HttpClient>,
  67    after_timestamps: &[String],
  68    max_rows_per_timestamp: usize,
  69    background_executor: BackgroundExecutor,
  70) -> Result<Vec<Example>> {
  71    if after_timestamps.is_empty() {
  72        return Ok(Vec::new());
  73    }
  74
  75    let progress = Progress::global();
  76
  77    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
  78        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
  79    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
  80        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
  81    )?;
  82    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
  83
  84    let mut all_examples = Vec::new();
  85
  86    for after_date in after_timestamps.iter() {
  87        let step_progress_name = format!(">{after_date}");
  88        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
  89        step_progress.set_substatus("querying");
  90
  91        let statement = indoc! {r#"
  92            SELECT
  93                event_properties:example AS example
  94            FROM events
  95            WHERE event_type = ?
  96                AND time > TRY_TO_TIMESTAMP_NTZ(?)
  97            ORDER BY time ASC
  98            LIMIT ?
  99        "#};
 100
 101        let request = json!({
 102            "statement": statement,
 103            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 104            "database": "EVENTS",
 105            "schema": "PUBLIC",
 106            "warehouse": "DBT",
 107            "role": role,
 108            "bindings": {
 109                "1": { "type": "TEXT", "value": EDIT_PREDICTION_EXAMPLE_CAPTURED_EVENT },
 110                "2": { "type": "TEXT", "value": after_date },
 111                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 112            }
 113        });
 114
 115        let response = run_sql_with_polling(
 116            http_client.clone(),
 117            &base_url,
 118            &token,
 119            &request,
 120            &step_progress,
 121            background_executor.clone(),
 122        )
 123        .await?;
 124
 125        let total_rows = response
 126            .result_set_meta_data
 127            .as_ref()
 128            .and_then(|m| m.num_rows)
 129            .unwrap_or(response.data.len() as i64);
 130
 131        let num_partitions = response
 132            .result_set_meta_data
 133            .as_ref()
 134            .map(|m| m.partition_info.len())
 135            .unwrap_or(1)
 136            .max(1);
 137
 138        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 139        step_progress.set_substatus("parsing");
 140
 141        let example_index = response
 142            .result_set_meta_data
 143            .as_ref()
 144            .and_then(|m| {
 145                m.row_type.iter().enumerate().find_map(|(index, col)| {
 146                    if col.name.eq_ignore_ascii_case("example") {
 147                        Some(index)
 148                    } else {
 149                        None
 150                    }
 151                })
 152            })
 153            .unwrap_or(0);
 154
 155        all_examples.extend(examples_from_response(&response, example_index)?);
 156
 157        if num_partitions > 1 {
 158            let statement_handle = response
 159                .statement_handle
 160                .as_ref()
 161                .context("response has multiple partitions but no statementHandle")?;
 162
 163            for partition in 1..num_partitions {
 164                step_progress.set_substatus(format!(
 165                    "fetching partition {}/{}",
 166                    partition + 1,
 167                    num_partitions
 168                ));
 169
 170                let partition_response = fetch_partition(
 171                    http_client.clone(),
 172                    &base_url,
 173                    &token,
 174                    statement_handle,
 175                    partition,
 176                )
 177                .await?;
 178
 179                all_examples.extend(examples_from_response(&partition_response, example_index)?);
 180            }
 181        }
 182
 183        step_progress.set_substatus("done");
 184    }
 185
 186    Ok(all_examples)
 187}
 188
/// Body of a Snowflake SQL API statement response, limited to the fields this module reads.
///
/// Field names are matched in camelCase, and every field falls back to its default when it is
/// absent from the JSON. The same type is used for the initial statement response, for poll
/// responses, and for partition fetches.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SnowflakeStatementResponse {
    /// Result rows; each row is a list of JSON cells indexed by column position.
    #[serde(default)]
    pub(crate) data: Vec<Vec<JsonValue>>,
    /// Result-set metadata (column names, row count, partitions), when provided.
    #[serde(default)]
    pub(crate) result_set_meta_data: Option<SnowflakeResultSetMetaData>,
    /// Status code; compared against `SNOWFLAKE_SUCCESS_CODE` and
    /// `SNOWFLAKE_ASYNC_IN_PROGRESS_CODE`.
    #[serde(default)]
    pub(crate) code: Option<String>,
    /// Human-readable message; included in the error raised for a non-success `code`.
    #[serde(default)]
    pub(crate) message: Option<String>,
    /// Handle identifying the statement; required for polling and for fetching further
    /// partitions.
    #[serde(default)]
    pub(crate) statement_handle: Option<String>,
}
 203
 204#[derive(Debug, Clone, Deserialize)]
 205#[serde(rename_all = "camelCase")]
 206pub(crate) struct SnowflakeResultSetMetaData {
 207    #[serde(default, rename = "rowType")]
 208    row_type: Vec<SnowflakeColumnMeta>,
 209    #[serde(default)]
 210    num_rows: Option<i64>,
 211    #[serde(default)]
 212    partition_info: Vec<SnowflakePartitionInfo>,
 213}
 214
/// Placeholder for one entry of the metadata's `partitionInfo` array.
///
/// No fields are declared: this module only counts the entries to learn how many partitions
/// the result set has.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SnowflakePartitionInfo {}
 218
/// Descriptor of one result column; only the column name is read.
#[derive(Debug, Clone, Deserialize)]
struct SnowflakeColumnMeta {
    /// Column name as reported by the server (empty when absent). Lookups in this file compare
    /// it ASCII-case-insensitively.
    #[serde(default)]
    name: String,
}
 224
 225fn examples_from_response(
 226    response: &SnowflakeStatementResponse,
 227    example_index: usize,
 228) -> Result<impl Iterator<Item = Example> + '_> {
 229    if let Some(code) = &response.code {
 230        if code != SNOWFLAKE_SUCCESS_CODE {
 231            anyhow::bail!(
 232                "snowflake sql api returned error code={code} message={}",
 233                response.message.as_deref().unwrap_or("<no message>")
 234            );
 235        }
 236    }
 237
 238    let iter = response.data.iter().enumerate().filter_map(move |(row_index, data_row)| {
 239        let Some(example_value) = data_row.get(example_index) else {
 240            return None;
 241        };
 242        if example_value.is_null() {
 243            return None;
 244        }
 245
 246        let parse_result = match example_value {
 247            JsonValue::String(encoded_json) => serde_json::from_str::<ExampleSpec>(encoded_json),
 248            _ => serde_json::from_value::<ExampleSpec>(example_value.clone()),
 249        };
 250
 251        match parse_result {
 252            Ok(spec) => Some(Example {
 253                spec,
 254                prompt_inputs: None,
 255                prompt: None,
 256                predictions: Vec::new(),
 257                score: Vec::new(),
 258                qa: Vec::new(),
 259                state: None,
 260            }),
 261            Err(error) => {
 262                let raw_json = serde_json::to_string_pretty(example_value)
 263                    .unwrap_or_else(|_| "<failed to serialize json>".to_string());
 264                log::error!(
 265                    "failed to parse ExampleSpec for row {row_index}: {error:#}\nraw json:\n{raw_json}"
 266                );
 267                None
 268            }
 269        }
 270    });
 271
 272    Ok(iter)
 273}
 274
 275async fn run_sql_with_polling(
 276    http_client: Arc<dyn HttpClient>,
 277    base_url: &str,
 278    token: &str,
 279    request: &serde_json::Value,
 280    step_progress: &crate::progress::StepProgress,
 281    background_executor: BackgroundExecutor,
 282) -> Result<SnowflakeStatementResponse> {
 283    let mut response = run_sql(http_client.clone(), base_url, token, request).await?;
 284
 285    if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 286        let statement_handle = response
 287            .statement_handle
 288            .as_ref()
 289            .context("async query response missing statementHandle")?
 290            .clone();
 291
 292        for attempt in 1..=MAX_POLL_ATTEMPTS {
 293            step_progress.set_substatus(format!("polling ({attempt})"));
 294
 295            background_executor.timer(POLL_INTERVAL).await;
 296
 297            response =
 298                fetch_partition(http_client.clone(), base_url, token, &statement_handle, 0).await?;
 299
 300            if response.code.as_deref() != Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 301                break;
 302            }
 303        }
 304
 305        if response.code.as_deref() == Some(SNOWFLAKE_ASYNC_IN_PROGRESS_CODE) {
 306            anyhow::bail!(
 307                "query still running after {} poll attempts ({} seconds)",
 308                MAX_POLL_ATTEMPTS,
 309                MAX_POLL_ATTEMPTS as u64 * POLL_INTERVAL.as_secs()
 310            );
 311        }
 312    }
 313
 314    Ok(response)
 315}
 316
 317pub(crate) async fn fetch_partition(
 318    http_client: Arc<dyn HttpClient>,
 319    base_url: &str,
 320    token: &str,
 321    statement_handle: &str,
 322    partition: usize,
 323) -> Result<SnowflakeStatementResponse> {
 324    let url = format!(
 325        "{}/api/v2/statements/{}?partition={}",
 326        base_url.trim_end_matches('/'),
 327        statement_handle,
 328        partition
 329    );
 330
 331    let http_request = Request::builder()
 332        .method(Method::GET)
 333        .uri(url.as_str())
 334        .header("Authorization", format!("Bearer {token}"))
 335        .header(
 336            "X-Snowflake-Authorization-Token-Type",
 337            "PROGRAMMATIC_ACCESS_TOKEN",
 338        )
 339        .header("Accept", "application/json")
 340        .header("Accept-Encoding", "gzip")
 341        .header("User-Agent", "edit_prediction_cli")
 342        .body(AsyncBody::empty())?;
 343
 344    let response = http_client
 345        .send(http_request)
 346        .await
 347        .context("failed to send partition request to Snowflake SQL API")?;
 348
 349    let status = response.status();
 350    let content_encoding = response
 351        .headers()
 352        .get("content-encoding")
 353        .and_then(|v| v.to_str().ok())
 354        .map(|s| s.to_lowercase());
 355
 356    let body_bytes = {
 357        use futures::AsyncReadExt as _;
 358
 359        let mut body = response.into_body();
 360        let mut bytes = Vec::new();
 361        body.read_to_end(&mut bytes)
 362            .await
 363            .context("failed to read Snowflake SQL API partition response body")?;
 364        bytes
 365    };
 366
 367    let body_bytes = if content_encoding.as_deref() == Some("gzip") {
 368        let mut decoder = GzDecoder::new(&body_bytes[..]);
 369        let mut decompressed = Vec::new();
 370        decoder
 371            .read_to_end(&mut decompressed)
 372            .context("failed to decompress gzip response")?;
 373        decompressed
 374    } else {
 375        body_bytes
 376    };
 377
 378    if !status.is_success() && status.as_u16() != 202 {
 379        let body_text = String::from_utf8_lossy(&body_bytes);
 380        anyhow::bail!(
 381            "snowflake sql api partition request http {}: {}",
 382            status.as_u16(),
 383            body_text
 384        );
 385    }
 386
 387    if body_bytes.is_empty() {
 388        anyhow::bail!(
 389            "snowflake sql api partition {} returned empty response body (http {})",
 390            partition,
 391            status.as_u16()
 392        );
 393    }
 394
 395    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes).with_context(|| {
 396        let body_preview = String::from_utf8_lossy(&body_bytes[..body_bytes.len().min(500)]);
 397        format!(
 398            "failed to parse Snowflake SQL API partition {} response JSON (http {}): {}",
 399            partition,
 400            status.as_u16(),
 401            body_preview
 402        )
 403    })
 404}
 405
 406pub(crate) async fn run_sql(
 407    http_client: Arc<dyn HttpClient>,
 408    base_url: &str,
 409    token: &str,
 410    request: &serde_json::Value,
 411) -> Result<SnowflakeStatementResponse> {
 412    let url = format!("{}/api/v2/statements", base_url.trim_end_matches('/'));
 413
 414    let request_body =
 415        serde_json::to_vec(request).context("failed to serialize Snowflake SQL API request")?;
 416
 417    let http_request = Request::builder()
 418        .method(Method::POST)
 419        .uri(url.as_str())
 420        .header("Authorization", format!("Bearer {token}"))
 421        .header(
 422            "X-Snowflake-Authorization-Token-Type",
 423            "PROGRAMMATIC_ACCESS_TOKEN",
 424        )
 425        .header("Content-Type", "application/json")
 426        .header("Accept", "application/json")
 427        .header("User-Agent", "edit_prediction_cli")
 428        .body(AsyncBody::from(request_body.clone()))?;
 429
 430    let response = http_client
 431        .send(http_request)
 432        .await
 433        .context("failed to send request to Snowflake SQL API")?;
 434
 435    let status = response.status();
 436    let body_bytes = {
 437        use futures::AsyncReadExt as _;
 438
 439        let mut body = response.into_body();
 440        let mut bytes = Vec::new();
 441        body.read_to_end(&mut bytes)
 442            .await
 443            .context("failed to read Snowflake SQL API response body")?;
 444        bytes
 445    };
 446
 447    if !status.is_success() && status.as_u16() != 202 {
 448        let body_text = String::from_utf8_lossy(&body_bytes);
 449        anyhow::bail!("snowflake sql api http {}: {}", status.as_u16(), body_text);
 450    }
 451
 452    serde_json::from_slice::<SnowflakeStatementResponse>(&body_bytes)
 453        .context("failed to parse Snowflake SQL API response JSON")
 454}
 455
 456pub async fn fetch_rejected_examples_after(
 457    http_client: Arc<dyn HttpClient>,
 458    after_timestamps: &[String],
 459    max_rows_per_timestamp: usize,
 460    background_executor: BackgroundExecutor,
 461) -> Result<Vec<Example>> {
 462    if after_timestamps.is_empty() {
 463        return Ok(Vec::new());
 464    }
 465
 466    let progress = Progress::global();
 467
 468    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 469        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 470    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 471        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 472    )?;
 473    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 474
 475    let mut all_examples = Vec::new();
 476
 477    for after_date in after_timestamps.iter() {
 478        let step_progress_name = format!("rejected>{after_date}");
 479        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 480        step_progress.set_substatus("querying");
 481
 482        // Join rejected events with their corresponding request events to get the full context.
 483        // We filter for V3 sampling data which contains the structured input we need.
 484        // We also filter for predictions that were actually shown to the user (was_shown = true)
 485        // to focus on explicit user rejections rather than implicit cancellations.
 486        let statement = indoc! {r#"
 487            SELECT
 488                req.event_properties:request_id::string AS request_id,
 489                req.device_id::string AS device_id,
 490                req.time::string AS time,
 491                req.event_properties:input AS input,
 492                req.event_properties:prompt::string AS prompt,
 493                req.event_properties:output::string AS output,
 494                rej.event_properties:was_shown::boolean AS was_shown,
 495                rej.event_properties:reason::string AS reason
 496            FROM events req
 497            INNER JOIN events rej
 498                ON req.event_properties:request_id = rej.event_properties:request_id
 499            WHERE req.event_type = ?
 500                AND rej.event_type = ?
 501                AND req.event_properties:version = 'V3'
 502                AND rej.event_properties:was_shown = true
 503                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 504            ORDER BY req.time ASC
 505            LIMIT ?
 506        "#};
 507
 508        let request = json!({
 509            "statement": statement,
 510            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 511            "database": "EVENTS",
 512            "schema": "PUBLIC",
 513            "warehouse": "DBT",
 514            "role": role,
 515            "bindings": {
 516                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 517                "2": { "type": "TEXT", "value": PREDICTIVE_EDIT_REJECTED_EVENT },
 518                "3": { "type": "TEXT", "value": after_date },
 519                "4": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 520            }
 521        });
 522
 523        let response = run_sql_with_polling(
 524            http_client.clone(),
 525            &base_url,
 526            &token,
 527            &request,
 528            &step_progress,
 529            background_executor.clone(),
 530        )
 531        .await?;
 532
 533        let total_rows = response
 534            .result_set_meta_data
 535            .as_ref()
 536            .and_then(|m| m.num_rows)
 537            .unwrap_or(response.data.len() as i64);
 538
 539        let num_partitions = response
 540            .result_set_meta_data
 541            .as_ref()
 542            .map(|m| m.partition_info.len())
 543            .unwrap_or(1)
 544            .max(1);
 545
 546        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 547        step_progress.set_substatus("parsing");
 548
 549        let column_indices = get_column_indices(
 550            &response.result_set_meta_data,
 551            &[
 552                "request_id",
 553                "device_id",
 554                "time",
 555                "input",
 556                "prompt",
 557                "output",
 558                "was_shown",
 559                "reason",
 560            ],
 561        );
 562
 563        all_examples.extend(rejected_examples_from_response(&response, &column_indices)?);
 564
 565        if num_partitions > 1 {
 566            let statement_handle = response
 567                .statement_handle
 568                .as_ref()
 569                .context("response has multiple partitions but no statementHandle")?;
 570
 571            for partition in 1..num_partitions {
 572                step_progress.set_substatus(format!(
 573                    "fetching partition {}/{}",
 574                    partition + 1,
 575                    num_partitions
 576                ));
 577
 578                let partition_response = fetch_partition(
 579                    http_client.clone(),
 580                    &base_url,
 581                    &token,
 582                    statement_handle,
 583                    partition,
 584                )
 585                .await?;
 586
 587                all_examples.extend(rejected_examples_from_response(
 588                    &partition_response,
 589                    &column_indices,
 590                )?);
 591            }
 592        }
 593
 594        step_progress.set_substatus("done");
 595    }
 596
 597    Ok(all_examples)
 598}
 599
 600pub async fn fetch_requested_examples_after(
 601    http_client: Arc<dyn HttpClient>,
 602    after_timestamps: &[String],
 603    max_rows_per_timestamp: usize,
 604    background_executor: BackgroundExecutor,
 605) -> Result<Vec<Example>> {
 606    if after_timestamps.is_empty() {
 607        return Ok(Vec::new());
 608    }
 609
 610    let progress = Progress::global();
 611
 612    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 613        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 614    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 615        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 616    )?;
 617    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 618
 619    let mut all_examples = Vec::new();
 620
 621    for after_date in after_timestamps.iter() {
 622        let step_progress_name = format!("requested>{after_date}");
 623        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 624        step_progress.set_substatus("querying");
 625
 626        let statement = indoc! {r#"
 627            SELECT
 628                req.event_properties:request_id::string AS request_id,
 629                req.device_id::string AS device_id,
 630                req.time::string AS time,
 631                req.event_properties:input AS input
 632            FROM events req
 633            WHERE req.event_type = ?
 634                AND req.event_properties:version = 'V3'
 635                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
 636            ORDER BY req.time ASC
 637            LIMIT ?
 638        "#};
 639
 640        let request = json!({
 641            "statement": statement,
 642            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 643            "database": "EVENTS",
 644            "schema": "PUBLIC",
 645            "warehouse": "DBT",
 646            "role": role,
 647            "bindings": {
 648                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 649                "2": { "type": "TEXT", "value": after_date },
 650                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 651            }
 652        });
 653
 654        let response = run_sql_with_polling(
 655            http_client.clone(),
 656            &base_url,
 657            &token,
 658            &request,
 659            &step_progress,
 660            background_executor.clone(),
 661        )
 662        .await?;
 663
 664        let total_rows = response
 665            .result_set_meta_data
 666            .as_ref()
 667            .and_then(|m| m.num_rows)
 668            .unwrap_or(response.data.len() as i64);
 669
 670        let num_partitions = response
 671            .result_set_meta_data
 672            .as_ref()
 673            .map(|m| m.partition_info.len())
 674            .unwrap_or(1)
 675            .max(1);
 676
 677        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 678        step_progress.set_substatus("parsing");
 679
 680        let column_indices = get_column_indices(
 681            &response.result_set_meta_data,
 682            &["request_id", "device_id", "time", "input"],
 683        );
 684
 685        all_examples.extend(requested_examples_from_response(
 686            &response,
 687            &column_indices,
 688        )?);
 689
 690        if num_partitions > 1 {
 691            let statement_handle = response
 692                .statement_handle
 693                .as_ref()
 694                .context("response has multiple partitions but no statementHandle")?;
 695
 696            for partition in 1..num_partitions {
 697                step_progress.set_substatus(format!(
 698                    "fetching partition {}/{}",
 699                    partition + 1,
 700                    num_partitions
 701                ));
 702
 703                let partition_response = fetch_partition(
 704                    http_client.clone(),
 705                    &base_url,
 706                    &token,
 707                    statement_handle,
 708                    partition,
 709                )
 710                .await?;
 711
 712                all_examples.extend(requested_examples_from_response(
 713                    &partition_response,
 714                    &column_indices,
 715                )?);
 716            }
 717        }
 718
 719        step_progress.set_substatus("done");
 720    }
 721
 722    Ok(all_examples)
 723}
 724
 725pub async fn fetch_rated_examples_after(
 726    http_client: Arc<dyn HttpClient>,
 727    inputs: &[(String, Option<EditPredictionRating>)],
 728    max_rows_per_timestamp: usize,
 729    background_executor: BackgroundExecutor,
 730) -> Result<Vec<Example>> {
 731    if inputs.is_empty() {
 732        return Ok(Vec::new());
 733    }
 734
 735    let progress = Progress::global();
 736
 737    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
 738        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
 739    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
 740        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
 741    )?;
 742    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
 743
 744    let mut all_examples = Vec::new();
 745
 746    for (after_date, rating_filter) in inputs.iter() {
 747        let filter_label = match rating_filter {
 748            None => "",
 749            Some(EditPredictionRating::Positive) => ":positive",
 750            Some(EditPredictionRating::Negative) => ":negative",
 751        };
 752        let step_progress_name = format!("rated{filter_label}>{after_date}");
 753        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
 754        step_progress.set_substatus("querying");
 755
 756        let rating_value = rating_filter.as_ref().map(|r| match r {
 757            EditPredictionRating::Positive => "Positive",
 758            EditPredictionRating::Negative => "Negative",
 759        });
 760
 761        let statement = indoc! {r#"
 762            SELECT
 763                rated.event_properties:request_id::string AS request_id,
 764                rated.event_properties:inputs AS inputs,
 765                rated.event_properties:output::string AS output,
 766                rated.event_properties:rating::string AS rating,
 767                rated.event_properties:feedback::string AS feedback,
 768                rated.device_id::string AS device_id,
 769                rated.time::string AS time,
 770                deploy.event_properties:experiment_name::string AS experiment_name,
 771                deploy.event_properties:environment::string AS environment
 772            FROM events rated
 773            LEFT JOIN events req
 774                ON rated.event_properties:request_id::string = req.event_properties:request_id::string
 775                AND req.event_type = ?
 776            LEFT JOIN events deploy
 777                ON req.event_properties:headers:x_baseten_model_id::string = deploy.event_properties:model_id::string
 778                AND req.event_properties:headers:x_baseten_model_version_id::string = deploy.event_properties:model_version_id::string
 779                AND deploy.event_type = ?
 780            WHERE rated.event_type = ?
 781                AND (? IS NULL OR rated.event_properties:rating::string = ?)
 782                AND rated.time > TRY_TO_TIMESTAMP_NTZ(?)
 783                AND rated.event_properties:inputs IS NOT NULL
 784                AND rated.event_properties:inputs:cursor_excerpt IS NOT NULL
 785                AND rated.event_properties:output IS NOT NULL
 786            ORDER BY rated.time ASC
 787            LIMIT ?
 788        "#};
 789
 790        let bindings = json!({
 791            "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
 792            "2": { "type": "TEXT", "value": EDIT_PREDICTION_DEPLOYMENT_EVENT },
 793            "3": { "type": "TEXT", "value": EDIT_PREDICTION_RATED_EVENT },
 794            "4": { "type": "TEXT", "value": rating_value },
 795            "5": { "type": "TEXT", "value": rating_value },
 796            "6": { "type": "TEXT", "value": after_date },
 797            "7": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
 798        });
 799
 800        let request = json!({
 801            "statement": statement,
 802            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
 803            "database": "EVENTS",
 804            "schema": "PUBLIC",
 805            "warehouse": "DBT",
 806            "role": role,
 807            "bindings": bindings
 808        });
 809
 810        let response = run_sql_with_polling(
 811            http_client.clone(),
 812            &base_url,
 813            &token,
 814            &request,
 815            &step_progress,
 816            background_executor.clone(),
 817        )
 818        .await?;
 819
 820        let total_rows = response
 821            .result_set_meta_data
 822            .as_ref()
 823            .and_then(|m| m.num_rows)
 824            .unwrap_or(response.data.len() as i64);
 825
 826        let num_partitions = response
 827            .result_set_meta_data
 828            .as_ref()
 829            .map(|m| m.partition_info.len())
 830            .unwrap_or(1)
 831            .max(1);
 832
 833        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
 834        step_progress.set_substatus("parsing");
 835
 836        let column_indices = get_column_indices(
 837            &response.result_set_meta_data,
 838            &[
 839                "request_id",
 840                "inputs",
 841                "output",
 842                "rating",
 843                "feedback",
 844                "device_id",
 845                "time",
 846                "experiment_name",
 847                "environment",
 848            ],
 849        );
 850
 851        all_examples.extend(rated_examples_from_response(&response, &column_indices)?);
 852
 853        if num_partitions > 1 {
 854            let statement_handle = response
 855                .statement_handle
 856                .as_ref()
 857                .context("response has multiple partitions but no statementHandle")?;
 858
 859            for partition in 1..num_partitions {
 860                step_progress.set_substatus(format!(
 861                    "fetching partition {}/{}",
 862                    partition + 1,
 863                    num_partitions
 864                ));
 865
 866                let partition_response = fetch_partition(
 867                    http_client.clone(),
 868                    &base_url,
 869                    &token,
 870                    statement_handle,
 871                    partition,
 872                )
 873                .await?;
 874
 875                all_examples.extend(rated_examples_from_response(
 876                    &partition_response,
 877                    &column_indices,
 878                )?);
 879            }
 880        }
 881
 882        step_progress.set_substatus("done");
 883    }
 884
 885    Ok(all_examples)
 886}
 887
 888fn rated_examples_from_response<'a>(
 889    response: &'a SnowflakeStatementResponse,
 890    column_indices: &'a std::collections::HashMap<String, usize>,
 891) -> Result<impl Iterator<Item = Example> + 'a> {
 892    if let Some(code) = &response.code {
 893        if code != SNOWFLAKE_SUCCESS_CODE {
 894            anyhow::bail!(
 895                "snowflake sql api returned error code={code} message={}",
 896                response.message.as_deref().unwrap_or("<no message>")
 897            );
 898        }
 899    }
 900
 901    let iter = response
 902        .data
 903        .iter()
 904        .enumerate()
 905        .filter_map(move |(row_index, data_row)| {
 906            let get_string = |name: &str| -> Option<String> {
 907                let index = column_indices.get(name).copied()?;
 908                match data_row.get(index)? {
 909                    JsonValue::String(s) => Some(s.clone()),
 910                    JsonValue::Null => None,
 911                    other => Some(other.to_string()),
 912                }
 913            };
 914
 915            let get_json = |name: &str| -> Option<JsonValue> {
 916                let index = column_indices.get(name).copied()?;
 917                let value = data_row.get(index)?;
 918                if value.is_null() {
 919                    return None;
 920                }
 921                match value {
 922                    JsonValue::String(s) => serde_json::from_str(s).ok(),
 923                    other => Some(other.clone()),
 924                }
 925            };
 926
 927            let request_id = get_string("request_id");
 928            let inputs_json = get_json("inputs");
 929            let inputs: Option<ZetaPromptInput> = match &inputs_json {
 930                Some(v) => match serde_json::from_value(v.clone()) {
 931                    Ok(parsed) => Some(parsed),
 932                    Err(e) => {
 933                        log::warn!(
 934                            "skipping row {row_index}: failed to parse inputs - {e}",
 935                        );
 936                        return None;
 937                    }
 938                },
 939                None => None,
 940            };
 941            let output = get_string("output");
 942            let rating = get_string("rating");
 943            let feedback = get_string("feedback").unwrap_or_default();
 944            let device_id = get_string("device_id");
 945            let time = get_string("time");
 946            let experiment_name = get_string("experiment_name");
 947            let environment = get_string("environment");
 948
 949            match (inputs, output.clone(), rating.clone(), device_id.clone(), time.clone()) {
 950                (Some(inputs), Some(output), Some(rating), Some(device_id), Some(time)) => {
 951                    Some(build_rated_example(
 952                        request_id,
 953                        device_id,
 954                        time,
 955                        inputs,
 956                        output,
 957                        rating,
 958                        feedback,
 959                        experiment_name,
 960                        environment,
 961                    ))
 962                }
 963                _ => {
 964                    log::warn!(
 965                        "skipping row {row_index}: missing fields - inputs={:?} output={:?} rating={:?} device_id={:?} time={:?}",
 966                        inputs_json.is_some(),
 967                        output.is_some(),
 968                        rating.is_some(),
 969                        device_id.is_some(),
 970                        time.is_some(),
 971                    );
 972                    None
 973                }
 974            }
 975        });
 976
 977    Ok(iter)
 978}
 979
 980fn build_rated_example(
 981    request_id: Option<String>,
 982    device_id: String,
 983    time: String,
 984    input: ZetaPromptInput,
 985    output: String,
 986    rating: String,
 987    feedback: String,
 988    experiment_name: Option<String>,
 989    environment: Option<String>,
 990) -> Example {
 991    let parsed_rating = if rating == "Positive" {
 992        EditPredictionRating::Positive
 993    } else {
 994        EditPredictionRating::Negative
 995    };
 996    let is_positive = parsed_rating == EditPredictionRating::Positive;
 997    let request_id = request_id.unwrap_or_else(|| format!("rated-{}-{}", device_id, time));
 998
 999    let mut tags = Vec::with_capacity(3);
1000    tags.push(if is_positive {
1001        "rated:positive".to_string()
1002    } else {
1003        "rated:negative".to_string()
1004    });
1005    if let Some(experiment) = experiment_name {
1006        tags.push(format!("experiment:{experiment}"));
1007    }
1008    if let Some(env) = environment {
1009        tags.push(format!("environment:{env}"));
1010    }
1011
1012    let mut example = build_example_from_snowflake(request_id, device_id, time, input, tags, None);
1013
1014    example.spec.rating = Some(parsed_rating);
1015
1016    if !feedback.is_empty() {
1017        example
1018            .spec
1019            .human_feedback
1020            .push(edit_prediction::example_spec::HumanFeedback { message: feedback });
1021    }
1022
1023    if is_positive {
1024        example.spec.expected_patches = vec![output];
1025    } else {
1026        example.spec.rejected_patch = Some(output);
1027    }
1028
1029    example
1030}
1031
1032fn requested_examples_from_response<'a>(
1033    response: &'a SnowflakeStatementResponse,
1034    column_indices: &'a std::collections::HashMap<String, usize>,
1035) -> Result<impl Iterator<Item = Example> + 'a> {
1036    if let Some(code) = &response.code {
1037        if code != SNOWFLAKE_SUCCESS_CODE {
1038            anyhow::bail!(
1039                "snowflake sql api returned error code={code} message={}",
1040                response.message.as_deref().unwrap_or("<no message>")
1041            );
1042        }
1043    }
1044
1045    let iter = response
1046        .data
1047        .iter()
1048        .enumerate()
1049        .filter_map(move |(row_index, data_row)| {
1050            let get_string = |name: &str| -> Option<String> {
1051                let index = column_indices.get(name).copied()?;
1052                match data_row.get(index)? {
1053                    JsonValue::String(s) => Some(s.clone()),
1054                    JsonValue::Null => None,
1055                    other => Some(other.to_string()),
1056                }
1057            };
1058
1059            let get_json = |name: &str| -> Option<JsonValue> {
1060                let index = column_indices.get(name).copied()?;
1061                let value = data_row.get(index)?;
1062                if value.is_null() {
1063                    return None;
1064                }
1065                match value {
1066                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1067                    other => Some(other.clone()),
1068                }
1069            };
1070
1071            let request_id_str = get_string("request_id");
1072            let device_id = get_string("device_id");
1073            let time = get_string("time");
1074            let input_json = get_json("input");
1075            let input: Option<ZetaPromptInput> =
1076                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1077
1078            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
1079                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
1080                    Some(build_example_from_snowflake(
1081                        request_id,
1082                        device_id,
1083                        time,
1084                        input,
1085                        vec!["requested".to_string()],
1086                        None,
1087                    ))
1088                }
1089                _ => {
1090                    log::warn!(
1091                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
1092                        request_id_str.is_some(),
1093                        device_id.is_some(),
1094                        time.is_some(),
1095                        input_json.is_some(),
1096                    );
1097                    None
1098                }
1099            }
1100        });
1101
1102    Ok(iter)
1103}
1104
1105fn rejected_examples_from_response<'a>(
1106    response: &'a SnowflakeStatementResponse,
1107    column_indices: &'a std::collections::HashMap<String, usize>,
1108) -> Result<impl Iterator<Item = Example> + 'a> {
1109    if let Some(code) = &response.code {
1110        if code != SNOWFLAKE_SUCCESS_CODE {
1111            anyhow::bail!(
1112                "snowflake sql api returned error code={code} message={}",
1113                response.message.as_deref().unwrap_or("<no message>")
1114            );
1115        }
1116    }
1117
1118    let iter = response
1119        .data
1120        .iter()
1121        .enumerate()
1122        .filter_map(move |(row_index, data_row)| {
1123            let get_string = |name: &str| -> Option<String> {
1124                let index = column_indices.get(name).copied()?;
1125                match data_row.get(index)? {
1126                    JsonValue::String(s) => Some(s.clone()),
1127                    JsonValue::Null => None,
1128                    other => Some(other.to_string()),
1129                }
1130            };
1131
1132            let get_json = |name: &str| -> Option<JsonValue> {
1133                let index = column_indices.get(name).copied()?;
1134                let value = data_row.get(index)?;
1135                if value.is_null() {
1136                    return None;
1137                }
1138                match value {
1139                    JsonValue::String(s) => serde_json::from_str(s).ok(),
1140                    other => Some(other.clone()),
1141                }
1142            };
1143
1144            let get_bool = |name: &str| -> Option<bool> {
1145                let index = column_indices.get(name).copied()?;
1146                match data_row.get(index)? {
1147                    JsonValue::Bool(b) => Some(*b),
1148                    JsonValue::String(s) => s.parse().ok(),
1149                    _ => None,
1150                }
1151            };
1152
1153            let request_id_str = get_string("request_id");
1154            let device_id = get_string("device_id");
1155            let time = get_string("time");
1156            let input_json = get_json("input");
1157            let input: Option<ZetaPromptInput> =
1158                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
1159            let output = get_string("output");
1160            let was_shown = get_bool("was_shown");
1161            let reason = get_string("reason");
1162
1163            match (request_id_str.clone(), device_id.clone(), time.clone(), input, output.clone(), was_shown, reason.clone()) {
1164                (Some(request_id), Some(device_id), Some(time), Some(input), Some(output), Some(was_shown), Some(reason)) => {
1165                    Some(build_rejected_example(
1166                        request_id,
1167                        device_id,
1168                        time,
1169                        input,
1170                        output,
1171                        was_shown,
1172                        reason,
1173                    ))
1174                }
1175                _ => {
1176                    log::warn!(
1177                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?} output={:?} was_shown={:?} reason={:?}",
1178                        request_id_str.is_some(),
1179                        device_id.is_some(),
1180                        time.is_some(),
1181                        input_json.is_some(),
1182                        output.is_some(),
1183                        was_shown.is_some(),
1184                        reason.is_some()
1185                    );
1186                    None
1187                }
1188            }
1189        });
1190
1191    Ok(iter)
1192}
1193
1194fn build_rejected_example(
1195    request_id: String,
1196    device_id: String,
1197    time: String,
1198    input: ZetaPromptInput,
1199    output: String,
1200    was_shown: bool,
1201    reason: String,
1202) -> Example {
1203    let rejected_patch = build_output_patch(
1204        &input.cursor_path,
1205        input.cursor_excerpt.as_ref(),
1206        &input.editable_range_in_excerpt,
1207        &output,
1208    );
1209    let mut example = build_example_from_snowflake(
1210        request_id,
1211        device_id,
1212        time,
1213        input,
1214        vec![format!("rejection:{}", reason.to_lowercase())],
1215        Some(RejectionInfo { reason, was_shown }),
1216    );
1217    example.spec.rejected_patch = Some(rejected_patch);
1218    example
1219}
1220
/// Rejection details attached to an example built from a rejected-prediction row.
///
/// `build_example_from_snowflake` copies both fields into the example's
/// `TelemetrySource` (as `rejection_reason` and `was_shown`).
struct RejectionInfo {
    /// Rejection reason taken from the row's `reason` column.
    reason: String,
    /// Value of the row's `was_shown` column.
    // NOTE(review): presumably whether the prediction was displayed to the user — confirm.
    was_shown: bool,
}
1225
1226fn build_example_from_snowflake(
1227    request_id: String,
1228    device_id: String,
1229    time: String,
1230    input: ZetaPromptInput,
1231    tags: Vec<String>,
1232    rejection: Option<RejectionInfo>,
1233) -> Example {
1234    let events: Vec<CapturedEvent> = input
1235        .events
1236        .iter()
1237        .map(|event| match event.as_ref() {
1238            zeta_prompt::Event::BufferChange {
1239                path,
1240                old_path,
1241                diff,
1242                predicted,
1243                in_open_source_repo,
1244            } => CapturedEvent {
1245                path: path.clone(),
1246                old_path: old_path.clone(),
1247                diff: diff.clone(),
1248                predicted: *predicted,
1249                in_open_source_repo: *in_open_source_repo,
1250            },
1251        })
1252        .collect();
1253
1254    let related_files: Vec<CapturedRelatedFile> = input
1255        .related_files
1256        .iter()
1257        .map(|rf| CapturedRelatedFile {
1258            path: rf.path.clone(),
1259            max_row: rf.max_row,
1260            excerpts: rf
1261                .excerpts
1262                .iter()
1263                .map(|e| CapturedRelatedExcerpt {
1264                    row_range: e.row_range.clone(),
1265                    text: e.text.to_string(),
1266                })
1267                .collect(),
1268        })
1269        .collect();
1270
1271    let cursor_excerpt = input.cursor_excerpt.as_ref();
1272    let cursor_offset = input.cursor_offset_in_excerpt;
1273
1274    let (cursor_row, cursor_column) = compute_row_column(cursor_excerpt, cursor_offset);
1275
1276    let mut edit_history = String::new();
1277    for event in &input.events {
1278        zeta_prompt::write_event(&mut edit_history, event);
1279        edit_history.push('\n');
1280    }
1281
1282    let (rejection_reason, was_shown) = match &rejection {
1283        Some(r) => (r.reason.clone(), r.was_shown),
1284        None => (String::new(), false),
1285    };
1286
1287    let spec = ExampleSpec {
1288        name: request_id.clone(),
1289        repository_url: String::new(),
1290        revision: String::new(),
1291        tags,
1292        reasoning: None,
1293        uncommitted_diff: String::new(),
1294        cursor_path: input.cursor_path.clone(),
1295        cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
1296        edit_history,
1297        expected_patches: Vec::new(),
1298        rejected_patch: None,
1299        captured_prompt_input: Some(CapturedPromptInput {
1300            cursor_file_content: cursor_excerpt.to_string(),
1301            cursor_offset,
1302            cursor_row,
1303            cursor_column,
1304            excerpt_start_row: None,
1305            events,
1306            related_files,
1307        }),
1308        telemetry: Some(TelemetrySource {
1309            request_id,
1310            device_id,
1311            time,
1312            rejection_reason,
1313            was_shown,
1314        }),
1315        human_feedback: Vec::new(),
1316        rating: None,
1317    };
1318
1319    Example {
1320        spec,
1321        prompt_inputs: None,
1322        prompt: None,
1323        predictions: Vec::new(),
1324        score: Vec::new(),
1325        qa: Vec::new(),
1326        state: None,
1327    }
1328}
1329
/// Translates a byte `offset` into `text` to a zero-based `(row, column)` pair.
///
/// `row` is the number of `\n` characters that start before `offset`. `column`
/// is the byte distance from the start of that row to `offset`; it is not
/// clamped, so an offset past the end of `text` yields a column past the end
/// of the last line.
fn compute_row_column(text: &str, offset: usize) -> (u32, u32) {
    // A newline is a single byte that never occurs inside a multi-byte UTF-8
    // sequence, so scanning bytes finds the same line breaks as scanning chars.
    let scanned = &text.as_bytes()[..offset.min(text.len())];

    let row = scanned.iter().filter(|&&byte| byte == b'\n').count() as u32;
    let line_start = scanned
        .iter()
        .rposition(|&byte| byte == b'\n')
        .map_or(0, |newline_index| newline_index + 1);

    (row, (offset - line_start) as u32)
}
1345
/// Renders `excerpt` with a `[CURSOR_POSITION]` marker inserted at `cursor_offset`.
///
/// `cursor_offset` is a byte offset. An offset past the end is clamped to the
/// end of the excerpt, and an offset that lands inside a multi-byte character
/// is moved back to the start of that character, so slicing never panics.
fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
    let mut split_at = cursor_offset.min(excerpt.len());
    // Index 0 is always a char boundary, so this loop terminates.
    while !excerpt.is_char_boundary(split_at) {
        split_at -= 1;
    }

    let (before, after) = excerpt.split_at(split_at);
    format!("{}[CURSOR_POSITION]{}", before, after)
}
1351
1352fn build_output_patch(
1353    cursor_path: &std::path::Path,
1354    cursor_excerpt: &str,
1355    editable_range: &std::ops::Range<usize>,
1356    model_output: &str,
1357) -> String {
1358    let old_text = &cursor_excerpt[editable_range.clone()];
1359
1360    let editable_start_row = cursor_excerpt[..editable_range.start]
1361        .chars()
1362        .filter(|&c| c == '\n')
1363        .count() as u32;
1364
1365    let diff_body = language::unified_diff_with_offsets(
1366        old_text,
1367        model_output,
1368        editable_start_row,
1369        editable_start_row,
1370    );
1371
1372    let mut patch = String::new();
1373    writeln!(&mut patch, "--- a/{}", cursor_path.display()).ok();
1374    writeln!(&mut patch, "+++ b/{}", cursor_path.display()).ok();
1375    patch.push_str(&diff_body);
1376    patch
1377}
1378
1379pub(crate) fn get_column_indices(
1380    meta: &Option<SnowflakeResultSetMetaData>,
1381    names: &[&str],
1382) -> std::collections::HashMap<String, usize> {
1383    let mut indices = std::collections::HashMap::new();
1384    if let Some(meta) = meta {
1385        for (index, col) in meta.row_type.iter().enumerate() {
1386            for &name in names {
1387                if col.name.eq_ignore_ascii_case(name) {
1388                    indices.insert(name.to_string(), index);
1389                }
1390            }
1391        }
1392    }
1393    indices
1394}